seriesClassification.cpp 3.16 KB
Newer Older
1
2
3
//
//  seriesClassification.cpp
//
//  Created by Michael Zbyszynski on 08/06/2017.
//  Copyright © 2017 Goldsmiths. All rights reserved.
//

#include <vector>
#include "seriesClassification.h"
#ifdef EMSCRIPTEN
#include "emscripten/seriesClassificationEmbindings.h"
#endif

// Constructor: performs no explicit initialisation of its own.
seriesClassification::seriesClassification() = default;

// Destructor: no explicit clean-up is performed here.
seriesClassification::~seriesClassification() = default;

mzed's avatar
mzed committed
18
bool seriesClassification::addSeries(const std::vector<std::vector<double>> &newSeries) {
19
20
21
22
23
24
    dtw newDTW;
    newDTW.setSeries(newSeries);
    dtwClassifiers.push_back(newDTW);
    return true;
}

mzed's avatar
mzed committed
25
bool seriesClassification::addSeries(const std::vector<trainingExample> &trainingSet) {
mzed's avatar
mzed committed
26
27
28
29
30
31
32
    std::vector<std::vector<double>> newSeries;
    for (int i = 0; i < trainingSet.size(); ++i) {
        newSeries.push_back(trainingSet[i].input);
    }
    return addSeries(newSeries);
};

mzed's avatar
mzed committed
33
34
35
36
///////////////////////////////////////////////// Training
//TODO: Refactor these

bool seriesClassification::train(const std::vector<std::vector<std::vector<double> > > &vectorSet) {
mzed's avatar
mzed committed
37
38
    bool trained = true;
    reset();
mzed's avatar
mzed committed
39
40
    for (int i = 0; i < vectorSet.size(); ++i) {
        if (!addSeries(vectorSet[i])) {
mzed's avatar
mzed committed
41
42
43
44
45
46
            trained = false;
        };
    }
    return trained;
}

mzed's avatar
mzed committed
47
bool seriesClassification::train(const std::vector<std::vector<trainingExample> > &exampleSet) {
mzed's avatar
mzed committed
48
49
    bool trained = true;
    reset();
mzed's avatar
mzed committed
50
51
    for (int i = 0; i < exampleSet.size(); ++i) {
        if (!addSeries(exampleSet[i])) {
mzed's avatar
mzed committed
52
53
54
55
56
57
            trained = false;
        };
    }
    return trained;
}

mzed's avatar
mzed committed
58
59
60
61
62
63
64
65
66
67
68
69
70
71
bool seriesClassification::trainLabel(const std::vector<trainingSeries> &seriesSet) {
    bool trained = true;
    reset();
    for (int i = 0; i < seriesSet.size(); ++i) {
        if(!addSeries(seriesSet[i].input) ) {
            trained = false;
        }
        labels.push_back(seriesSet[i].label);
    }
    return trained;
};

/////////////////////////////////////////////////

mzed's avatar
mzed committed
72
void seriesClassification::reset() {
mzed's avatar
mzed committed
73
    labels.clear();
74
75
76
    dtwClassifiers.clear();
}

mzed's avatar
mzed committed
77
int seriesClassification::run(const std::vector<std::vector<double>> &inputSeries) {
78
79
    //TODO: check vector sizes and reject bad data
    int closestSeries = 0;
mzed's avatar
mzed committed
80
    allCosts.clear();
mzed's avatar
mzed committed
81
    double lowestCost = dtwClassifiers[0].run(inputSeries);
mzed's avatar
mzed committed
82
83
    allCosts.push_back(lowestCost);
    for (int i = 1; i < dtwClassifiers.size(); ++i) {
mzed's avatar
mzed committed
84
        double currentCost = dtwClassifiers[i].run(inputSeries);
mzed's avatar
mzed committed
85
        allCosts.push_back(currentCost);
86
87
88
89
90
91
        if (currentCost < lowestCost) {
            lowestCost = currentCost;
            closestSeries = i;
        }
    }
    return closestSeries;
mzed's avatar
mzed committed
92
93
};

mzed's avatar
mzed committed
94
int seriesClassification::run(const std::vector<trainingExample> &trainingSet) {
mzed's avatar
mzed committed
95
96
97
98
    std::vector<std::vector<double>> newSeries;
    for (int i = 0; i < trainingSet.size(); ++i) {
        newSeries.push_back(trainingSet[i].input);
    }
mzed's avatar
mzed committed
99
    return run(newSeries);
mzed's avatar
mzed committed
100
101
};

mzed's avatar
mzed committed
102
103
104
105
std::string seriesClassification::runLabel(const std::vector<std::vector<double>> &inputSeries) {
    return labels[run(inputSeries)];
};

mzed's avatar
mzed committed
106
107
108
109
110
// Returns a copy of allCosts, the costs recorded by the most recent run().
std::vector<double> seriesClassification::getCosts() {
    std::vector<double> costsCopy(allCosts);
    return costsCopy;
}

std::vector<double> seriesClassification::getCosts(const std::vector<trainingExample> &trainingSet) {
mzed's avatar
mzed committed
111
    run(trainingSet);
mzed's avatar
mzed committed
112
113
    return allCosts;
}