Commit e3ec45c1 authored by mzed's avatar mzed
Browse files

dtw has no state now

parent 75b62040
......@@ -195,93 +195,23 @@ int main(int argc, const char * argv[]) {
////////////
//////////////////////////////////////////////////////////////////////// DTW
seriesClassification myDtw;
seriesClassification myDtwTrain;
//Testing addSeries()
//Test series
std::vector<std::vector<double>> seriesOne;
seriesOne.push_back( { 1., 5.} );
seriesOne.push_back( { 2., 4.} );
seriesOne.push_back( { 3., 3.} );
seriesOne.push_back( { 4., 2.} );
seriesOne.push_back( { 5., 1.} );
myDtw.addSeries(seriesOne);
std::vector<std::vector<double>> seriesTwo;
seriesTwo.push_back( { 1., 4. } );
seriesTwo.push_back( { 2., -3. } );
seriesTwo.push_back( { 1., 5. } );
seriesTwo.push_back( { -2., 1. } );
myDtw.addSeries(seriesTwo);
assert(myDtw.run(seriesOne) == 0);
assert(myDtw.getCosts()[0] == 0);
assert(myDtw.getCosts()[1] == 19.325403217417502);
assert(myDtw.run(seriesTwo) == 1);
//testing train()
std::vector<std::vector<std::vector<double>>> seriesSet;
seriesSet.push_back(seriesOne);
seriesSet.push_back(seriesTwo);
myDtwTrain.train(seriesSet);
assert(myDtwTrain.run(seriesOne) == 0);
assert(myDtwTrain.run(seriesTwo) == 1);
//Testing with training examples... probably don't neet this?
seriesClassification myDtw2;
std::vector<trainingExample> tsOne;
tempExample.input = { 1., 5. };
tsOne.push_back(tempExample);
tempExample.input = { 2., 4. };
tsOne.push_back(tempExample);
tempExample.input = { 3., 3. };
tsOne.push_back(tempExample);
tempExample.input = { 4., 2. };
tsOne.push_back(tempExample);
tempExample.input = { 5., 1. };
tsOne.push_back(tempExample);
myDtw2.addSeries(tsOne);
std::vector<trainingExample> tsTwo;
tempExample.input = { 1., 4. };
tsTwo.push_back(tempExample);
tempExample.input = { 2., -3. };
tsTwo.push_back(tempExample);
tempExample.input = { 1., 5. };
tsTwo.push_back(tempExample);
tempExample.input = { -2., 1. };
tsTwo.push_back(tempExample);
myDtw2.addSeries(tsTwo);
assert(myDtw2.run(tsOne) == 0);
assert(myDtw2.run(tsTwo) == 1);
seriesClassification myDtw2T;
seriesSet.clear();
seriesSet.push_back(seriesOne);
seriesSet.push_back(seriesTwo);
myDtw2T.train(seriesSet);
assert(myDtw2T.run(tsOne) == 0);
assert(myDtw2T.run(tsTwo) == 1);
assert(myDtw.getCosts()[0] == 19.325403217417502);
assert(myDtw2T.getCosts(tsOne)[0] == 0);
///////////////////////////////////////////////////////////Testing with labels
seriesClassification myDtwLabel;
std::vector<trainingSeries> seriesVector;
......
......@@ -8,44 +8,49 @@
#include <vector>
#include <cmath>
#include <cassert>
#include "dtw.h"
dtw::dtw() {};
dtw::~dtw() {};
void dtw::setSeries(std::vector<std::vector<double>> newSeries) {
storedSeries = newSeries;
numFeatures = int(storedSeries[0].size());
};
//void dtw::setSeries(std::vector<std::vector<double>> newSeries) {
// storedSeries = newSeries;
// numFeatures = int(storedSeries[0].size());
//};
inline double dtw::distanceFunction(std::vector<double> x, std::vector<double> y) {
//TODO: make sure series are same length
inline double dtw::distanceFunction(const std::vector<double> &x, const std::vector<double> &y) {
assert(x.size() == y.size());
double euclidianDistance = 0;
for(int j = 0; j < numFeatures ; ++j){
for(int j = 0; j < x.size() ; ++j){
euclidianDistance = euclidianDistance + pow((x[j] - y[j]), 2);
}
euclidianDistance = sqrt(euclidianDistance);
return euclidianDistance;
};
double dtw::run(std::vector<std::vector<double> > inputSeries) {
std::vector<std::vector<double> > costMatrix(inputSeries.size(), std::vector<double>(storedSeries.size(), 0));
int maxInput = int(inputSeries.size()) - 1;
int maxStored = int(storedSeries.size()) - 1;
double dtw::getCost(const std::vector<std::vector<double> > &seriesX, const std::vector<std::vector<double> > &seriesY) {
if (seriesX.size() < seriesY.size()) {
return getCost(seriesY, seriesX);
}
std::vector<std::vector<double> > costMatrix(seriesX.size(), std::vector<double>(seriesY.size(), 0));
int maxInput = int(seriesX.size()) - 1;
int maxStored = int(seriesY.size()) - 1;
//Calculate values for the first column
costMatrix[0][0] = distanceFunction(inputSeries[0], storedSeries[0]);
costMatrix[0][0] = distanceFunction(seriesX[0], seriesY[0]);
for (int j = 1; j <= maxStored; ++j) {
costMatrix[0][j] = costMatrix[0][j - 1] + distanceFunction(inputSeries[0], storedSeries[j]);
costMatrix[0][j] = costMatrix[0][j - 1] + distanceFunction(seriesX[0], seriesY[j]);
}
for (int i = 1; i <= maxInput; ++i) {
//Bottom row of current column
costMatrix[i][0] = costMatrix[i - 1][0] + distanceFunction(inputSeries[i], storedSeries[0]);
costMatrix[i][0] = costMatrix[i - 1][0] + distanceFunction(seriesX[i], seriesY[0]);
for (int j = 1; j <= maxStored; ++j) {
double minGlobalCost = fmin(costMatrix[i-1][j-1], costMatrix[i][j-1]);
costMatrix[i][j] = minGlobalCost + distanceFunction(inputSeries[i], storedSeries[j]);
costMatrix[i][j] = minGlobalCost + distanceFunction(seriesX[i], seriesY[j]);
}
}
double minimumCost = costMatrix[maxInput][maxStored];
......
......@@ -17,14 +17,14 @@ public:
dtw();
~dtw();
void setSeries(std::vector<std::vector<double>> newSeries);
double run(std::vector<std::vector<double>> inputSeries);
void reset();
//void setSeries(std::vector<std::vector<double>> newSeries);
double getCost(const std::vector<std::vector<double>> &seriesX, const std::vector<std::vector<double > > &seriesY);
//void reset();
private:
std::vector<std::vector<double>> storedSeries;
int numFeatures;
inline double distanceFunction(std::vector<double> seriesX, std::vector<double> seriesY);
//std::vector<std::vector<double>> storedSeries;
//int numFeatures;
inline double distanceFunction(const std::vector<double> &pointX, const std::vector<double> &point);
};
......
......@@ -15,99 +15,45 @@ seriesClassification::seriesClassification() {};
seriesClassification::~seriesClassification() {};
bool seriesClassification::addSeries(const std::vector<std::vector<double>> &newSeries) {
dtw newDTW;
newDTW.setSeries(newSeries);
dtwClassifiers.push_back(newDTW);
return true;
}
bool seriesClassification::addSeries(const std::vector<trainingExample> &trainingSet) {
std::vector<std::vector<double>> newSeries;
for (int i = 0; i < trainingSet.size(); ++i) {
newSeries.push_back(trainingSet[i].input);
}
return addSeries(newSeries);
};
///////////////////////////////////////////////// Training
//TODO: Refactor these
bool seriesClassification::train(const std::vector<std::vector<std::vector<double> > > &vectorSet) {
bool trained = true;
bool seriesClassification::trainLabel(const std::vector<trainingSeries> &seriesSet) {
reset();
for (int i = 0; i < vectorSet.size(); ++i) {
if (!addSeries(vectorSet[i])) {
trained = false;
};
}
return trained;
}
bool seriesClassification::train(const std::vector<std::vector<trainingExample> > &exampleSet) {
bool trained = true;
reset();
for (int i = 0; i < exampleSet.size(); ++i) {
if (!addSeries(exampleSet[i])) {
trained = false;
};
}
return trained;
}
allTrainingSeries = seriesSet;
bool seriesClassification::trainLabel(const std::vector<trainingSeries> &seriesSet) {
bool trained = true;
reset();
for (int i = 0; i < seriesSet.size(); ++i) {
if(!addSeries(seriesSet[i].input) ) {
trained = false;
}
labels.push_back(seriesSet[i].label);
}
//TODO: calculate some size statistics here?
return trained;
};
/////////////////////////////////////////////////
void seriesClassification::reset() {
labels.clear();
dtwClassifiers.clear();
allCosts.clear();
allTrainingSeries.clear();
}
int seriesClassification::run(const std::vector<std::vector<double>> &inputSeries) {
//TODO: check vector sizes and reject bad data
std::string seriesClassification::runLabel(const std::vector<std::vector<double>> &inputSeries) {
dtw dtw;
int closestSeries = 0;
allCosts.clear();
double lowestCost = dtwClassifiers[0].run(inputSeries);
double lowestCost = dtw.getCost(inputSeries, allTrainingSeries[0].input);
allCosts.push_back(lowestCost);
for (int i = 1; i < dtwClassifiers.size(); ++i) {
double currentCost = dtwClassifiers[i].run(inputSeries);
for (int i = 1; i < allTrainingSeries.size(); ++i) {
double currentCost = dtw.getCost(inputSeries, allTrainingSeries[i].input);
allCosts.push_back(currentCost);
if (currentCost < lowestCost) {
lowestCost = currentCost;
closestSeries = i;
}
}
return closestSeries;
return allTrainingSeries[closestSeries].label;
};
int seriesClassification::run(const std::vector<trainingExample> &trainingSet) {
std::vector<std::vector<double>> newSeries;
for (int i = 0; i < trainingSet.size(); ++i) {
newSeries.push_back(trainingSet[i].input);
}
return run(newSeries);
};
std::string seriesClassification::runLabel(const std::vector<std::vector<double>> &inputSeries) {
return labels[run(inputSeries)];
};
std::vector<double> seriesClassification::getCosts() {
return allCosts;
}
std::vector<double> seriesClassification::getCosts(const std::vector<trainingExample> &trainingSet) {
run(trainingSet);
return allCosts;
}
\ No newline at end of file
//std::vector<double> seriesClassification::getCosts() {
// return allCosts;
//}
//
//std::vector<double> seriesClassification::getCosts(const std::vector<trainingExample> &trainingSet) {
// run(trainingSet);
// return allCosts;
//}
\ No newline at end of file
......@@ -20,26 +20,15 @@ public:
seriesClassification();
~seriesClassification();
bool addSeries(const std::vector<std::vector<double>> &newSeries);
bool addSeries(const std::vector<trainingExample> &trainingSet);
bool train(const std::vector<std::vector<std::vector<double> > > &vectorSet);
bool train(const std::vector<std::vector<trainingExample>> &exampleSet);
bool trainLabel(const std::vector<trainingSeries> &seriesSet);
void reset();
int run(const std::vector<std::vector<double>> &inputSeries);
int run(const std::vector<trainingExample> &inputSet);
std::string runLabel(const std::vector<std::vector<double>> &inputSeries);
std::vector<double> getCosts();
std::vector<double> getCosts(const std::vector<trainingExample> &inputSet);
private:
std::vector<std::string> labels;
std::vector<dtw> dtwClassifiers;
std::vector<trainingSeries> allTrainingSeries;
std::vector<double> allCosts;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment