Commit 57473e23 authored by mzed's avatar mzed
Browse files

new train methods

parent 900371e0
......@@ -135,6 +135,7 @@ int main(int argc, const char * argv[]) {
////////////
seriesClassification myDtw;
seriesClassification myDtwTrain;
std::vector<std::vector<double>> seriesOne;
seriesOne.push_back( { 1., 5.} );
......@@ -154,6 +155,14 @@ int main(int argc, const char * argv[]) {
assert(myDtw.run(seriesOne) == 0);
assert(myDtw.run(seriesTwo) == 1);
std::vector<std::vector<std::vector<double>>> seriesSet;
seriesSet.push_back(seriesOne);
seriesSet.push_back(seriesTwo);
myDtwTrain.train(seriesSet);
assert(myDtwTrain.run(seriesOne) == 0);
assert(myDtwTrain.run(seriesTwo) == 1);
seriesClassification myDtw2;
std::vector<trainingExample> tsOne;
......@@ -192,8 +201,17 @@ int main(int argc, const char * argv[]) {
myDtw2.addTrainingSet(tsTwo);
std::cout << "dtw2: " << myDtw2.runTrainingSet(tsOne) << std::endl;
assert(myDtw2.runTrainingSet(tsOne) == 0);
assert(myDtw2.runTrainingSet(tsTwo) == 1);
seriesClassification myDtw2T;
seriesSet.clear();
seriesSet.push_back(seriesOne);
seriesSet.push_back(seriesTwo);
myDtw2T.train(seriesSet);
assert(myDtw2T.runTrainingSet(tsOne) == 0);
assert(myDtw2T.runTrainingSet(tsTwo) == 1);
return 0;
}
......@@ -16,7 +16,7 @@ seriesClassification::seriesClassification() {};
seriesClassification::~seriesClassification() {};
bool seriesClassification::addSeries(std::vector<std::vector<double>> newSeries) {
bool seriesClassification::addSeries(const std::vector<std::vector<double>> &newSeries) {
dtw newDTW;
newDTW.setSeries(newSeries);
dtwClassifiers.push_back(newDTW);
......@@ -31,12 +31,33 @@ bool seriesClassification::addTrainingSet(const std::vector<trainingExample> &tr
return addSeries(newSeries);
};
bool seriesClassification::train(const std::vector<std::vector<std::vector<double> > > &newSeriesSet) {
bool trained = true;
reset();
for (int i = 0; i < newSeriesSet.size(); ++i) {
if (!addSeries(newSeriesSet[i])) {
trained = false;
};
}
return trained;
}
bool seriesClassification::trainTrainingSet(const std::vector<std::vector<trainingExample> > &seriesSet) {
bool trained = true;
reset();
for (int i = 0; i < seriesSet.size(); ++i) {
if (!addTrainingSet(seriesSet[i])) {
trained = false;
};
}
return trained;
}
void seriesClassification::reset() {
dtwClassifiers.clear();
}
int seriesClassification::run(std::vector<std::vector<double>> inputSeries) {
int seriesClassification::run(const std::vector<std::vector<double>> &inputSeries) {
//TODO: check vector sizes and reject bad data
int closestSeries = 0;
double lowestCost = dtwClassifiers[0].run(inputSeries);
......
......@@ -19,13 +19,16 @@ public:
seriesClassification();
~seriesClassification();
bool addSeries(std::vector<std::vector<double>> newSeries);
bool addTrainingSet(const std::vector<trainingExample> &trainingSet); //hacky solution for JavaScipt. -mz
bool addSeries(const std::vector<std::vector<double>> &newSeries);
bool addTrainingSet(const std::vector<trainingExample> &trainingSet);
bool train(const std::vector<std::vector<std::vector<double>>> &newSeriesSet);
bool trainTrainingSet(const std::vector<std::vector<trainingExample>> &seriesSet);
void reset();
int run(std::vector<std::vector<double>> inputSeries);
int runTrainingSet(const std::vector<trainingExample> &trainingSet);
int run(const std::vector<std::vector<double>> &inputSeries);
int runTrainingSet(const std::vector<trainingExample> &inputSet);
private:
std::vector<dtw> dtwClassifiers;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment