//
//  seriesClassification.cpp
//  RapidLib
//
//  Created by Michael Zbyszynski on 08/06/2017.
//  Copyright © 2017 Goldsmiths. All rights reserved.
//

#include <vector>
mzed's avatar
mzed committed
10
#include <cassert>
11
12
13
14
15
#include "seriesClassification.h"
#ifdef EMSCRIPTEN
#include "emscripten/seriesClassificationEmbindings.h"
#endif

16
17
#define SEARCH_RADIUS 1

18
19
20
21
// Builds a classifier with no training data.
// The global length bounds are set to -1, the same "no data" value that reset()
// assigns, so getMinLength()/getMaxLength() return a defined value on a freshly
// constructed object instead of reading indeterminate ints.
seriesClassification::seriesClassification() {
    minLength = -1;
    maxLength = -1;
}

// Nothing to release by hand: the destructor body is empty, so the members
// are destroyed by their own destructors.
seriesClassification::~seriesClassification() = default;

22
// Replaces any previously stored training data with a copy of seriesSet and
// records the shortest and longest series length, both overall and per label.
//
// seriesSet: the labelled series to store; must not be empty.
// Returns true once the data has been stored. Returns false, leaving the
// classifier reset and empty, when seriesSet is empty and asserts are compiled out.
bool seriesClassification::train(const std::vector<trainingSeries<double> > &seriesSet) {
    // An empty training set is a caller error and is still flagged in debug builds.
    assert(seriesSet.size() > 0);
    reset();
    if (seriesSet.empty()) {
        // With NDEBUG the assert above disappears; refuse to train rather than
        // index the first element of an empty vector.
        return false;
    }
    allTrainingSeries = seriesSet;

    minLength = maxLength = static_cast<int>(allTrainingSeries.front().input.size());
    for (const auto &series : allTrainingSeries) {
        const int newLength = static_cast<int>(series.input.size());

        //Global
        if (newLength < minLength) {
            minLength = newLength;
        }
        if (newLength > maxLength) {
            maxLength = newLength;
        }

        //Per Label
        auto it = lengthsPerLabel.find(series.label);
        if (it != lengthsPerLabel.end()) {
            if (newLength < it->second.min) {
                it->second.min = newLength;
            }
            if (newLength > it->second.max) {
                it->second.max = newLength;
            }
        } else {
            // First series seen for this label: its length is both bounds.
            minMax<int> tempLengths;
            tempLengths.min = tempLengths.max = newLength;
            lengthsPerLabel[series.label] = tempLengths;
        }
    }
    return true;
}

mzed's avatar
mzed committed
56
void seriesClassification::reset() {
mzed's avatar
mzed committed
57
58
    allCosts.clear();
    allTrainingSeries.clear();
mzed's avatar
mzed committed
59
60
61
    lengthsPerLabel.clear();
    minLength = -1;
    maxLength = -1;
62
63
}

64
std::string seriesClassification::run(const std::vector<std::vector<double>> &inputSeries) {
65
    int closestSeries = 0;
mzed's avatar
mzed committed
66
    allCosts.clear();
mzed's avatar
mzed committed
67
    double lowestCost = fastDTW<double>::getCost(inputSeries, allTrainingSeries[0].input, SEARCH_RADIUS);
mzed's avatar
mzed committed
68
    allCosts.push_back(lowestCost);
mzed's avatar
mzed committed
69
70
    
    for (int i = 1; i < allTrainingSeries.size(); ++i) {
mzed's avatar
mzed committed
71
        double currentCost = fastDTW<double>::getCost(inputSeries, allTrainingSeries[i].input, SEARCH_RADIUS);
mzed's avatar
mzed committed
72
        allCosts.push_back(currentCost);
73
74
75
76
77
        if (currentCost < lowestCost) {
            lowestCost = currentCost;
            closestSeries = i;
        }
    }
mzed's avatar
mzed committed
78
    return allTrainingSeries[closestSeries].label;
mzed's avatar
mzed committed
79
80
};

mzed's avatar
mzed committed
81

82
83
84
85
// Returns a copy of allCosts, the list that run() rebuilds on every call with
// one cost per stored training series.
std::vector<double> seriesClassification::getCosts() {
    std::vector<double> costs(allCosts);
    return costs;
}

mzed's avatar
mzed committed
86
87
88
89
90
91
int seriesClassification::getMinLength() {
    return minLength;
}

int seriesClassification::getMinLength(std::string label) {
    int labelMinLength = -1;
92
    std::map<std::string, minMax<int> >::iterator it = lengthsPerLabel.find(label);
mzed's avatar
mzed committed
93
94
95
96
97
98
99
100
101
102
103
104
    if (it != lengthsPerLabel.end()) {
        labelMinLength = it->second.min;
    }
    return labelMinLength;
}

int seriesClassification::getMaxLength() {
    return maxLength;
}

int seriesClassification::getMaxLength(std::string label) {
    int labelMaxLength = -1;
105
    std::map<std::string, minMax<int> >::iterator it = lengthsPerLabel.find(label);
mzed's avatar
mzed committed
106
107
108
109
110
111
    if (it != lengthsPerLabel.end()) {
        labelMaxLength = it->second.max;
    }
    return labelMaxLength;
}

112
113
114
115
116
117
118
119
120
121
122
// Computes the lowest and highest fastDTW cost between every unordered pair of
// distinct stored training series that both carry `label`.
//
// label: the label whose series are compared with each other.
// Returns a minMax with both fields 0 when fewer than two stored series carry
// the label, including when nothing has been trained.
//
// The pairs are walked with iterators rather than an index bounded by
// `size() - 1`: that bound wrapped around on an empty training set and sent the
// loop out of bounds.
seriesClassification::minMax<double> seriesClassification::calculateCosts(std::string label) {
    minMax<double> calculatedMinMax;
    calculatedMinMax.min = calculatedMinMax.max = 0; // result when no pair exists
    bool foundPair = false;

    for (auto first = allTrainingSeries.begin(); first != allTrainingSeries.end(); ++first) { //these loops are a little different than the two-label case
        if (first->label != label) {
            continue;
        }
        // Start after `first` so each pair is visited once and no series is compared with itself.
        for (auto second = first + 1; second != allTrainingSeries.end(); ++second) {
            if (second->label != label) {
                continue;
            }
            const double currentCost = fastDTW<double>::getCost(first->input, second->input, SEARCH_RADIUS);
            if (!foundPair) {
                calculatedMinMax.min = calculatedMinMax.max = currentCost; //first match is both min and max
                foundPair = true;
            } else {
                if (currentCost < calculatedMinMax.min) {
                    calculatedMinMax.min = currentCost;
                }
                if (currentCost > calculatedMinMax.max) {
                    calculatedMinMax.max = currentCost;
                }
            }
        }
    }
    return calculatedMinMax;
}

// Computes the lowest and highest fastDTW cost over every pairing of a stored
// series labelled `label1` (first argument to getCost) with a stored series
// labelled `label2` (second argument). When both labels are equal, a series is
// also paired with itself and each pair is visited in both orders.
//
// label1, label2: the two labels whose series are compared.
// Returns a minMax with both fields 0 when no such pairing exists, matching the
// single-label overload. Previously this case returned the initial sentinels
// (largest double as min, smallest positive double as max).
seriesClassification::minMax<double> seriesClassification::calculateCosts(std::string label1, std::string label2) {
    minMax<double> calculatedMinMax;
    calculatedMinMax.min = calculatedMinMax.max = 0; // result when no pairing exists
    bool foundPair = false;

    for (auto first = allTrainingSeries.begin(); first != allTrainingSeries.end(); ++first) {
        if (first->label != label1) {
            continue;
        }
        for (auto second = allTrainingSeries.begin(); second != allTrainingSeries.end(); ++second) {
            if (second->label != label2) {
                continue;
            }
            const double currentCost = fastDTW<double>::getCost(first->input, second->input, SEARCH_RADIUS);
            if (!foundPair) {
                calculatedMinMax.min = calculatedMinMax.max = currentCost; //first match is both min and max
                foundPair = true;
            } else {
                if (currentCost < calculatedMinMax.min) {
                    calculatedMinMax.min = currentCost;
                }
                if (currentCost > calculatedMinMax.max) {
                    calculatedMinMax.max = currentCost;
                }
            }
        }
    }
    return calculatedMinMax;
}


mzed's avatar
mzed committed
174
175
176
177
178
//
//std::vector<double> seriesClassification::getCosts(const std::vector<trainingExample> &trainingSet) {
//    run(trainingSet);
//    return allCosts;
//}