Commit ddaa176f authored by mzed's avatar mzed
Browse files

adding label feature for DTW

parent 0afbc6ab
......@@ -137,6 +137,7 @@ int main(int argc, const char * argv[]) {
seriesClassification myDtw;
seriesClassification myDtwTrain;
//Testing addSeries()
std::vector<std::vector<double>> seriesOne;
seriesOne.push_back( { 1., 5.} );
seriesOne.push_back( { 2., 4.} );
......@@ -155,9 +156,10 @@ int main(int argc, const char * argv[]) {
assert(myDtw.run(seriesOne) == 0);
assert(myDtw.getCosts()[0] == 0);
assert(myDtw.getCosts()[1] == 19.325403217417502);
assert(myDtw.run(seriesTwo) == 1);
//testing train()
std::vector<std::vector<std::vector<double>>> seriesSet;
seriesSet.push_back(seriesOne);
seriesSet.push_back(seriesTwo);
......@@ -166,7 +168,7 @@ int main(int argc, const char * argv[]) {
assert(myDtwTrain.run(seriesOne) == 0);
assert(myDtwTrain.run(seriesTwo) == 1);
//Testing with training examples... probably don't neet this?
seriesClassification myDtw2;
std::vector<trainingExample> tsOne;
......@@ -185,7 +187,7 @@ int main(int argc, const char * argv[]) {
tempExample.input = { 5., 1. };
tsOne.push_back(tempExample);
myDtw2.addTrainingSet(tsOne);
myDtw2.addSeries(tsOne);
std::vector<trainingExample> tsTwo;
tempExample.input = { 1., 4. };
......@@ -200,11 +202,11 @@ int main(int argc, const char * argv[]) {
tempExample.input = { -2., 1. };
tsTwo.push_back(tempExample);
myDtw2.addTrainingSet(tsTwo);
myDtw2.addSeries(tsTwo);
assert(myDtw2.runTrainingSet(tsOne) == 0);
assert(myDtw2.runTrainingSet(tsTwo) == 1);
assert(myDtw2.run(tsOne) == 0);
assert(myDtw2.run(tsTwo) == 1);
seriesClassification myDtw2T;
seriesSet.clear();
......@@ -212,11 +214,35 @@ int main(int argc, const char * argv[]) {
seriesSet.push_back(seriesTwo);
myDtw2T.train(seriesSet);
assert(myDtw2T.runTrainingSet(tsOne) == 0);
assert(myDtw2T.runTrainingSet(tsTwo) == 1);
assert(myDtw2T.run(tsOne) == 0);
assert(myDtw2T.run(tsTwo) == 1);
assert(myDtw.getCosts()[0] == 19.325403217417502);
assert(myDtw2T.getCosts(tsOne)[0] == 0);
///////////////////////////////////////////////////////////Testing with labels
seriesClassification myDtwLabel;
std::vector<trainingSeries> seriesVector;
trainingSeries tempSeries;
tempSeries.input.push_back( { 1., 5.} );
tempSeries.input.push_back( { 2., 4.} );
tempSeries.input.push_back( { 3., 3.} );
tempSeries.input.push_back( { 4., 2.} );
tempSeries.input.push_back( { 5., 1.} );
tempSeries.label = "first series";
seriesVector.push_back(tempSeries);
tempSeries = {};
tempSeries.input.push_back( { 1., 4.} );
tempSeries.input.push_back( { 2., -3.} );
tempSeries.input.push_back( { 1., 5.} );
tempSeries.input.push_back( { -2., 1.} );
tempSeries.label = "second series";
seriesVector.push_back(tempSeries);
myDtwLabel.trainLabel(seriesVector);
std::cout << "label test 1: " << myDtwLabel.runLabel(seriesOne) << std::endl;
std::cout << "label test 2: " << myDtwLabel.runLabel(seriesTwo) << std::endl;
return 0;
}
......@@ -287,6 +287,7 @@ double neuralNetwork::run(const std::vector<double> &inputVector) {
for (int k=0; k <= numHiddenNodes; ++k){
outputNeuron += hiddenNeurons[numHiddenLayers - 1][k] * wHiddenOutput[k];
}
//if classifier, outputNeuron = activationFunction(outputNeuron), else...
outputNeuron = (outputNeuron * outRange) + outBase;
return outputNeuron;
}
......
//
// seriesClassification.cpp
// RapidAPI
//
// Created by mzed on 08/06/2017.
// Created by Michael Zbyszynski on 08/06/2017.
// Copyright © 2017 Goldsmiths. All rights reserved.
//
......@@ -23,7 +22,7 @@ bool seriesClassification::addSeries(const std::vector<std::vector<double>> &new
return true;
}
bool seriesClassification::addTrainingSet(const std::vector<trainingExample> &trainingSet) {
bool seriesClassification::addSeries(const std::vector<trainingExample> &trainingSet) {
std::vector<std::vector<double>> newSeries;
for (int i = 0; i < trainingSet.size(); ++i) {
newSeries.push_back(trainingSet[i].input);
......@@ -31,29 +30,47 @@ bool seriesClassification::addTrainingSet(const std::vector<trainingExample> &tr
return addSeries(newSeries);
};
bool seriesClassification::train(const std::vector<std::vector<std::vector<double> > > &newSeriesSet) {
///////////////////////////////////////////////// Training
//TODO: Refactor these
bool seriesClassification::train(const std::vector<std::vector<std::vector<double> > > &vectorSet) {
bool trained = true;
reset();
for (int i = 0; i < newSeriesSet.size(); ++i) {
if (!addSeries(newSeriesSet[i])) {
for (int i = 0; i < vectorSet.size(); ++i) {
if (!addSeries(vectorSet[i])) {
trained = false;
};
}
return trained;
}
bool seriesClassification::trainTrainingSet(const std::vector<std::vector<trainingExample> > &seriesSet) {
bool seriesClassification::train(const std::vector<std::vector<trainingExample> > &exampleSet) {
bool trained = true;
reset();
for (int i = 0; i < seriesSet.size(); ++i) {
if (!addTrainingSet(seriesSet[i])) {
for (int i = 0; i < exampleSet.size(); ++i) {
if (!addSeries(exampleSet[i])) {
trained = false;
};
}
return trained;
}
bool seriesClassification::trainLabel(const std::vector<trainingSeries> &seriesSet) {
bool trained = true;
reset();
for (int i = 0; i < seriesSet.size(); ++i) {
if(!addSeries(seriesSet[i].input) ) {
trained = false;
}
labels.push_back(seriesSet[i].label);
}
return trained;
};
/////////////////////////////////////////////////
void seriesClassification::reset() {
labels.clear();
dtwClassifiers.clear();
}
......@@ -74,7 +91,7 @@ int seriesClassification::run(const std::vector<std::vector<double>> &inputSerie
return closestSeries;
};
int seriesClassification::runTrainingSet(const std::vector<trainingExample> &trainingSet) {
int seriesClassification::run(const std::vector<trainingExample> &trainingSet) {
std::vector<std::vector<double>> newSeries;
for (int i = 0; i < trainingSet.size(); ++i) {
newSeries.push_back(trainingSet[i].input);
......@@ -82,11 +99,15 @@ int seriesClassification::runTrainingSet(const std::vector<trainingExample> &tra
return run(newSeries);
};
std::string seriesClassification::runLabel(const std::vector<std::vector<double>> &inputSeries) {
return labels[run(inputSeries)];
};
std::vector<double> seriesClassification::getCosts() {
return allCosts;
}
std::vector<double> seriesClassification::getCosts(const std::vector<trainingExample> &trainingSet) {
runTrainingSet(trainingSet);
run(trainingSet);
return allCosts;
}
\ No newline at end of file
//
// seriesClassification.hp
// seriesClassification.h
// RapidAPI
//
// Created by mzed on 08/06/2017.
// Created by Michael Zbyszynski on 08/06/2017.
// Copyright © 2017 Goldsmiths. All rights reserved.
//
......@@ -10,6 +10,7 @@
#define seriesClassification_hpp
#include <vector>
#include <string>
#include "dtw.h"
#include "trainingExample.h"
......@@ -20,20 +21,24 @@ public:
~seriesClassification();
bool addSeries(const std::vector<std::vector<double>> &newSeries);
bool addTrainingSet(const std::vector<trainingExample> &trainingSet);
bool addSeries(const std::vector<trainingExample> &trainingSet);
bool train(const std::vector<std::vector<std::vector<double>>> &newSeriesSet);
bool trainTrainingSet(const std::vector<std::vector<trainingExample>> &seriesSet);
bool train(const std::vector<std::vector<std::vector<double> > > &vectorSet);
bool train(const std::vector<std::vector<trainingExample>> &exampleSet);
bool trainLabel(const std::vector<trainingSeries> &seriesSet);
void reset();
int run(const std::vector<std::vector<double>> &inputSeries);
int runTrainingSet(const std::vector<trainingExample> &inputSet);
int run(const std::vector<trainingExample> &inputSet);
std::string runLabel(const std::vector<std::vector<double>> &inputSeries);
std::vector<double> getCosts();
std::vector<double> getCosts(const std::vector<trainingExample> &inputSet);
private:
std::vector<std::string> labels;
std::vector<dtw> dtwClassifiers;
std::vector<double> allCosts;
......
......@@ -2,6 +2,7 @@
#define trainingExample_h
#include <vector>
#include <string>
/** This is used by both NN and KNN models for training */
struct trainingExample {
......@@ -9,4 +10,9 @@ struct trainingExample {
std::vector<double> output;
};
struct trainingSeries {
std::vector<std::vector<double> > input;
std::string label;
};
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment