Commit 6f4ab7a3 authored by mzed's avatar mzed
Browse files

more stats in series classification

parent 2089474d
......@@ -237,13 +237,18 @@ int main(int argc, const char * argv[]) {
seriesVector.push_back(tempSeries);
myDTW.train(seriesVector);
std::cout << "dtwrun " << myDTW.run(seriesOne) << std::endl;
std::cout << "dtwrun " << myDTW.run(seriesTwo) << std::endl;
assert(myDTW.run(seriesOne) == "first series");
assert(myDTW.run(seriesTwo) == "second series");
std::cout << myDTW.getCosts()[0] << std::endl;
std::cout << myDTW.getCosts()[1] << std::endl;
//std::cout << myDTW.getCosts()[0] << std::endl;
//std::cout << myDTW.getCosts()[1] << std::endl;
assert(myDTW.getMaxLength() == 5);
assert(myDTW.getMinLength() == 4);
assert(myDTW.getMaxLength("first series") == 5);
assert(myDTW.getMinLength("first series") == 5);
assert(myDTW.getMaxLength("second series") == 4);
assert(myDTW.getMinLength("second series") == 4);
/*
fastDTW fastDtw;
std::cout << "fast one-two cost " << fastDtw.getCost(seriesOne, seriesTwo, 1) << std::endl;
......@@ -273,9 +278,9 @@ int main(int argc, const char * argv[]) {
myDTW.train(seriesVector);
inputSeries = tempSeries.input;
std::cout << "long match " << myDTW.run(inputSeries) << std::endl;
std::cout << myDTW.getCosts()[0] << std::endl;
std::cout << myDTW.getCosts()[1] << std::endl;
assert(myDTW.run(inputSeries) == "long down");
//std::cout << myDTW.getCosts()[0] << std::endl;
//std::cout << myDTW.getCosts()[1] << std::endl;
////////////////////////////////////////////////////////////////////////
......
......@@ -6,6 +6,7 @@
//
#include <vector>
#include <cassert>
#include "seriesClassification.h"
#ifdef EMSCRIPTEN
#include "emscripten/seriesClassificationEmbindings.h"
......@@ -16,29 +17,45 @@ seriesClassification::seriesClassification() {};
seriesClassification::~seriesClassification() {};
bool seriesClassification::train(const std::vector<trainingSeries> &seriesSet) {
assert(seriesSet.size() > 0);
reset();
bool trained = true;
allTrainingSeries = seriesSet;
minLength = maxLength = int(allTrainingSeries[0].input.size());
for (int i = 0; i < allTrainingSeries.size(); ++i) {
if (allTrainingSeries[i].input.size() < minLength) {
minLength = int(allTrainingSeries[i].input.size());
//Global
int newLength = int(allTrainingSeries[i].input.size());
if (newLength < minLength) {
minLength = newLength;
}
if (allTrainingSeries[i].input.size() > maxLength) {
maxLength = int(allTrainingSeries[i].input.size());
if (newLength > maxLength) {
maxLength = newLength;
}
//Per Label
std::map<std::string, lengths>::iterator it = lengthsPerLabel.find(allTrainingSeries[i].label);
if (it != lengthsPerLabel.end()) {
int newLength = int(allTrainingSeries[i].input.size());
if (newLength < it->second.min) {
it->second.min = newLength;
}
if (newLength > it->second.max) {
it->second.max = newLength;
}
} else {
lengths tempLengths;
tempLengths.min = tempLengths.max = int(allTrainingSeries[i].input.size());
lengthsPerLabel[allTrainingSeries[i].label] = tempLengths;
}
}
//TODO: calculate some size statistics here?
//min length per label
//max length per label
return trained;
};
void seriesClassification::reset() {
allCosts.clear();
allTrainingSeries.clear();
lengthsPerLabel.clear();
minLength = -1;
maxLength = -1;
}
std::string seriesClassification::run(const std::vector<std::vector<double>> &inputSeries) {
......@@ -65,6 +82,32 @@ std::vector<double> seriesClassification::getCosts() {
return allCosts;
}
int seriesClassification::getMinLength() {
return minLength;
}
int seriesClassification::getMinLength(std::string label) {
int labelMinLength = -1;
std::map<std::string, lengths>::iterator it = lengthsPerLabel.find(label);
if (it != lengthsPerLabel.end()) {
labelMinLength = it->second.min;
}
return labelMinLength;
}
int seriesClassification::getMaxLength() {
return maxLength;
}
int seriesClassification::getMaxLength(std::string label) {
int labelMaxLength = -1;
std::map<std::string, lengths>::iterator it = lengthsPerLabel.find(label);
if (it != lengthsPerLabel.end()) {
labelMaxLength = it->second.max;
}
return labelMaxLength;
}
//
//std::vector<double> seriesClassification::getCosts(const std::vector<trainingExample> &trainingSet) {
// run(trainingSet);
......
......@@ -11,6 +11,7 @@
#include <vector>
#include <string>
#include <map>
#include "fastDTW.h"
#include "trainingExample.h"
......@@ -25,12 +26,22 @@ public:
std::string run(const std::vector<std::vector<double>> &inputSeries);
std::vector<double> getCosts();
int getMinLength();
int getMinLength(std::string label);
int getMaxLength();
int getMaxLength(std::string label);
private:
std::vector<trainingSeries> allTrainingSeries;
std::vector<double> allCosts;
int maxLength;
int minLength;
struct lengths {
int min;
int max;
};
std::map<std::string, lengths> lengthsPerLabel;
};
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment