seriesClassification.cpp 2.09 KB
Newer Older
1
2
3
//
//  seriesClassification.cpp
//
mzed's avatar
mzed committed
4
//  Created by Michael Zbyszynski on 08/06/2017.
5
6
7
8
9
10
11
12
13
14
15
16
17
//  Copyright © 2017 Goldsmiths. All rights reserved.
//

#include <cstddef>
#include <string>
#include <vector>
#include "seriesClassification.h"
#ifdef EMSCRIPTEN
#include "emscripten/seriesClassificationEmbindings.h"
#endif

// Constructor: performs no initialisation beyond what members do on their own.
seriesClassification::seriesClassification() = default;

// Destructor: nothing to release explicitly.
seriesClassification::~seriesClassification() = default;

18
bool seriesClassification::train(const std::vector<trainingSeries> &seriesSet) {
mzed's avatar
mzed committed
19
20
    reset();
    bool trained = true;
mzed's avatar
mzed committed
21
    allTrainingSeries = seriesSet;
mzed's avatar
mzed committed
22
23
24
25
26
27
28
29
30
    minLength = maxLength = int(allTrainingSeries[0].input.size());
    for (int i = 0; i < allTrainingSeries.size(); ++i) {
        if (allTrainingSeries[i].input.size() < minLength) {
            minLength = int(allTrainingSeries[i].input.size());
        }
        if (allTrainingSeries[i].input.size() > maxLength) {
            maxLength = int(allTrainingSeries[i].input.size());
        }
    }
mzed's avatar
mzed committed
31

mzed's avatar
mzed committed
32
    //TODO: calculate some size statistics here?
mzed's avatar
mzed committed
33
34
    //min length per label
    //max length per label
mzed's avatar
mzed committed
35
    
mzed's avatar
mzed committed
36
37
38
    return trained;
};

mzed's avatar
mzed committed
39
void seriesClassification::reset() {
mzed's avatar
mzed committed
40
41
    allCosts.clear();
    allTrainingSeries.clear();
42
43
}

44
std::string seriesClassification::run(const std::vector<std::vector<double>> &inputSeries) {
mzed's avatar
mzed committed
45
    fastDTW fastDtw;
mzed's avatar
mzed committed
46
    int searchRadius = 1; //TODO: Define this properly, elsewhere?
47
    int closestSeries = 0;
mzed's avatar
mzed committed
48
    allCosts.clear();
mzed's avatar
mzed committed
49
    double lowestCost = fastDtw.getCost(inputSeries, allTrainingSeries[0].input, searchRadius);
mzed's avatar
mzed committed
50
    allCosts.push_back(lowestCost);
mzed's avatar
mzed committed
51
52
    
    for (int i = 1; i < allTrainingSeries.size(); ++i) {
mzed's avatar
mzed committed
53
        double currentCost = fastDtw.getCost(inputSeries, allTrainingSeries[i].input, searchRadius);
mzed's avatar
mzed committed
54
        allCosts.push_back(currentCost);
55
56
57
58
59
        if (currentCost < lowestCost) {
            lowestCost = currentCost;
            closestSeries = i;
        }
    }
mzed's avatar
mzed committed
60
    return allTrainingSeries[closestSeries].label;
mzed's avatar
mzed committed
61
62
};

mzed's avatar
mzed committed
63

64
65
66
67
/**
 * Returns a copy of the costs recorded by the most recent run().
 *
 * Element i is the cost against the i-th stored training series. The vector
 * is emptied by reset() (and therefore by train()) and rebuilt by each run().
 */
std::vector<double> seriesClassification::getCosts() {
    return std::vector<double>(allCosts.begin(), allCosts.end());
}

mzed's avatar
mzed committed
68
69
70
71
72
//
//std::vector<double> seriesClassification::getCosts(const std::vector<trainingExample> &trainingSet) {
//    run(trainingSet);
//    return allCosts;
//}