Commit dd605c5d authored by Michael Zbyszyński's avatar Michael Zbyszyński
Browse files

dtwTemporalClassification now takes labeled data and returns a label

parent 933a5701
RapidLib @ 337e093e
Subproject commit 19fd14352871f6795f2511123c4b438a67cec197
Subproject commit 337e093e00a8ed4aebd6b33bbd390aff6d1458b5
......@@ -40,22 +40,21 @@ bool machineLearning<regression>::train(const trainingData &newTrainingData) {
template<>
bool machineLearning<seriesClassification>::train(const trainingData &newTrainingData) {
std::vector<std::vector<trainingExample> > seriesSet;
std::vector<trainingSeries> seriesSet;
for (int i = 1; i < newTrainingData.trainingSet.size(); ++i) { //each phrase
std::vector<trainingExample> tempSeries;
trainingSeries tempSeries;
tempSeries.label = newTrainingData.trainingSet[i].label;
for (int j = 0; j < newTrainingData.trainingSet[i].elements.size(); ++j) { //each element
trainingExample tempExample;
tempExample.input = newTrainingData.trainingSet[i].elements[j].input;
tempSeries.push_back(tempExample);
tempSeries.input.push_back(newTrainingData.trainingSet[i].elements[j].input);
}
seriesSet.push_back(tempSeries);
}
return seriesClassification::trainTrainingSet(seriesSet);
return seriesClassification::trainLabel(seriesSet);
}
template<>
int machineLearning<seriesClassification>::run(const std::vector<std::vector<double> > &inputSeries) {
return seriesClassification::run(inputSeries);
std::string machineLearning<seriesClassification>::run(const std::vector<std::vector<double> > &inputSeries) {
return seriesClassification::runLabel(inputSeries);
}
......
......@@ -57,7 +57,7 @@ public:
}
//* This is the one I'm using for DTW */
int run(const std::vector<std::vector<double> > &inputSeries);
std::string run(const std::vector<std::vector<double> > &inputSeries);
bool reset() {
return MachineLearningModule::reset();
......
......@@ -149,7 +149,7 @@ SCENARIO("Test DTW classification", "[machineLearning]")
inputSet1.push_back( {1., 5.});
inputSet1.push_back( {-2., 1.});
REQUIRE(myDTW.run(inputSet1) == 1);
REQUIRE(myDTW.run(inputSet1) == "setTwo");
std::vector<std::vector<double> > inputSet0;
inputSet0.push_back( { 0.1, 0.5 });
......@@ -158,7 +158,7 @@ SCENARIO("Test DTW classification", "[machineLearning]")
inputSet0.push_back( { 0.4, 0.2 });
inputSet0.push_back( { 0.5, 0.1 });
REQUIRE(myDTW.run(inputSet0) == 0);
REQUIRE(myDTW.run(inputSet0) == "setOne");
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment