Commit c5891ae8 authored by Michael Zbyszyński's avatar Michael Zbyszyński
Browse files

DTW is working and tested

parent bb6692ae
......@@ -38,6 +38,25 @@ bool machineLearning<regression>::train(const trainingData &newTrainingData) {
return regression::train(trainingSet);
}
template<>
bool machineLearning<seriesClassification>::train(const trainingData &newTrainingData) {
std::vector<std::vector<trainingExample> > seriesSet;
for (int i = 1; i < newTrainingData.trainingSet.size(); ++i) { //each phrase
std::vector<trainingExample> tempSeries;
for (int j = 0; j < newTrainingData.trainingSet[i].elements.size(); ++j) { //each element
trainingExample tempExample;
tempExample.input = newTrainingData.trainingSet[i].elements[j].input;
tempSeries.push_back(tempExample);
}
seriesSet.push_back(tempSeries);
}
return seriesClassification::trainTrainingSet(seriesSet);
}
template<>
int machineLearning<seriesClassification>::run(const std::vector<std::vector<double> > &inputSeries) {
return seriesClassification::run(inputSeries);
}
RAPIDMIX_END_NAMESPACE
......@@ -13,6 +13,7 @@
////////// Include all of the machine learning algorithms here
#include "classification.h"
#include "regression.h"
#include "seriesClassification.h"
#include "rapidXmmTools.h"
#include "rapidGVF.h"
......@@ -37,9 +38,12 @@ public:
// Could overload this, or specialize, or both
std::vector<double> run(const std::vector<double> &inputVector) {
return MachineLearningModule::process(inputVector);
return MachineLearningModule::run(inputVector);
}
//* This is the one I'm using for DTW */
int run(const std::vector<std::vector<double> > &inputSeries);
bool reset() {
return MachineLearningModule::reset();
}
......@@ -52,6 +56,7 @@ private:
typedef machineLearning<classification> staticClassification;
typedef machineLearning<regression> staticRegression;
typedef machineLearning<seriesClassification> dtwTemporalClassification;
typedef xmmToolConfig xmmConfig;
typedef machineLearning<xmmGmmTool> xmmStaticClassification;
......
......@@ -146,14 +146,14 @@ bool xmmTool<SingleClassModel, Model>::readJSON(const std::string &filepath) {
//============================== xmmGmmTool ==================================//
std::vector<double> xmmGmmTool::process(const std::vector<double>& inputVector) {
std::vector<double> xmmGmmTool::run(const std::vector<double>& inputVector) {
xmmTool::preProcess(inputVector);
return model.results.smoothed_normalized_likelihoods;
}
//============================== xmmGmrTool ==================================//
std::vector<double> xmmGmrTool::process(const std::vector<double>& inputVector) {
std::vector<double> xmmGmrTool::run(const std::vector<double>& inputVector) {
xmmTool::preProcess(inputVector);
std::vector<float> *res = &model.results.output_values;
std::vector<double> dRes(res->begin(), res->end());
......@@ -162,7 +162,7 @@ std::vector<double> xmmGmrTool::process(const std::vector<double>& inputVector)
//============================== xmmHmmTool ==================================//
std::vector<double> xmmHmmTool::process(const std::vector<double>& inputVector) {
std::vector<double> xmmHmmTool::run(const std::vector<double>& inputVector) {
xmmTool::preProcess(inputVector);
std::vector<double> res;
......@@ -178,7 +178,7 @@ std::vector<double> xmmHmmTool::process(const std::vector<double>& inputVector)
//============================== xmmHmrTool ==================================//
std::vector<double> xmmHmrTool::process(const std::vector<double>& inputVector) {
std::vector<double> xmmHmrTool::run(const std::vector<double>& inputVector) {
xmmTool::preProcess(inputVector);
std::vector<float> *res = &model.results.output_values;
std::vector<double> dRes(res->begin(), res->end());
......
......@@ -216,7 +216,7 @@ public:
xmmStaticTool<xmm::GMM, xmm::GMM>(cfg, false) {}
~xmmGmmTool() {}
std::vector<double> process(const std::vector<double>& inputVector);
std::vector<double> run(const std::vector<double>& inputVector);
};
class xmmGmrTool : public xmmStaticTool<xmm::GMM, xmm::GMM> {
......@@ -225,7 +225,7 @@ public:
xmmStaticTool<xmm::GMM, xmm::GMM>(cfg, true) {}
~xmmGmrTool() {}
std::vector<double> process(const std::vector<double>& inputVector);
std::vector<double> run(const std::vector<double>& inputVector);
};
class xmmHmmTool : public xmmTemporalTool<xmm::HMM, xmm::HierarchicalHMM> {
......@@ -234,7 +234,7 @@ public:
xmmTemporalTool<xmm::HMM, xmm::HierarchicalHMM>(cfg, false) {}
~xmmHmmTool() {}
std::vector<double> process(const std::vector<double>& inputVector);
std::vector<double> run(const std::vector<double>& inputVector);
};
class xmmHmrTool : public xmmTemporalTool<xmm::HMM, xmm::HierarchicalHMM> {
......@@ -243,7 +243,7 @@ public:
xmmTemporalTool<xmm::HMM, xmm::HierarchicalHMM>(cfg, true) {}
~xmmHmrTool() {}
std::vector<double> process(const std::vector<double>& inputVector);
std::vector<double> run(const std::vector<double>& inputVector);
};
#endif /* _RAPID_XMM_TOOLS_H_ */
......@@ -62,7 +62,7 @@ SCENARIO("Test NN Regression", "[machineLearning]")
SCENARIO("Test kNN classification", "[machineLearning]")
{
GIVEN("kNN Regression object and training dataset")
GIVEN("kNN object and training dataset")
{
rapidmix::staticClassification myKnn;
......@@ -109,6 +109,59 @@ SCENARIO("Test kNN classification", "[machineLearning]")
}
}
SCENARIO("Test DTW classification", "[machineLearning]")
{
GIVEN("DTW object and training dataset")
{
rapidmix::dtwTemporalClassification myDTW;
rapidmix::trainingData myData;
myData.startRecording("setOne");
std::vector<double> input = { 0.1, 0.5 };
std::vector<double> output = {};
REQUIRE(myData.addElement(input, output) == 2); //TODO: Shouldn't this be 1?
input = { 0.2, 0.4 };
REQUIRE(myData.addElement(input, output) == 3);
input = { 0.3, 0.3 };
REQUIRE(myData.addElement(input, output) == 4);
input = { 0.4, 0.2 };
REQUIRE(myData.addElement(input, output) == 5);
input = { 0.5, 0.1 };
REQUIRE(myData.addElement(input, output) == 6);
myData.stopRecording();
myData.startRecording("setTwo");
input = { 1., 4. };
myData.addElement(input, output);
input = { 2., -3. };
myData.addElement(input, output);
input = { 1., 5. };
myData.addElement(input, output);
input = { -2., 1. };
myData.addElement(input, output);
myData.stopRecording();
REQUIRE(myDTW.train(myData) == true);
std::vector<std::vector<double> > inputSet1;
inputSet1.push_back( {1., 4.});
inputSet1.push_back( {2., -3.});
inputSet1.push_back( {1., 5.});
inputSet1.push_back( {-2., 1.});
REQUIRE(myDTW.run(inputSet1) == 1);
std::vector<std::vector<double> > inputSet0;
inputSet0.push_back( { 0.1, 0.5 });
inputSet0.push_back( { 0.2, 0.4 });
inputSet0.push_back( { 0.3, 0.3 });
inputSet0.push_back( { 0.4, 0.2 });
inputSet0.push_back( { 0.5, 0.1 });
REQUIRE(myDTW.run(inputSet0) == 0);
}
}
SCENARIO("Test both classes reject bad data", "[machineLearning]") {
rapidmix::staticRegression badNN;
rapidmix::staticClassification badKNN;
......
......@@ -44,9 +44,12 @@
BE9286491EF015AE006847CF /* test_signalProcessing.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE2C5EE31ED8480D00E9FAFA /* test_signalProcessing.cpp */; };
BE92864C1EF015E7006847CF /* rapidStream.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE92864A1EF015E7006847CF /* rapidStream.cpp */; };
BE92864D1EF01622006847CF /* rapidStream.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE92864A1EF015E7006847CF /* rapidStream.cpp */; };
BE9286511EF01801006847CF /* svmClassification.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE92864E1EF017E2006847CF /* svmClassification.cpp */; };
BE9286521EF01823006847CF /* libsvm.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE2C5F0F1EDD74BC00E9FAFA /* libsvm.cpp */; };
BE9286561EF01A2D006847CF /* seriesClassification.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE9286531EF01A23006847CF /* seriesClassification.cpp */; };
BE92865A1EF01C4A006847CF /* dtw.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE9286571EF01C45006847CF /* dtw.cpp */; };
BEA7B70F1EDD7B350003E84B /* machineLearning.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE2C5EBE1ED8459300E9FAFA /* machineLearning.cpp */; };
BEA7B7101EDD7B390003E84B /* trainingData.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE2C5EC01ED8459300E9FAFA /* trainingData.cpp */; };
BEA7B7111EDD7B640003E84B /* svm.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE2C5EED1ED849AA00E9FAFA /* svm.cpp */; };
BEA7B7121EDD7B660003E84B /* classification.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE2C5DEA1ED8450E00E9FAFA /* classification.cpp */; };
BEA7B7131EDD7B690003E84B /* knnClassification.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE2C5DF31ED8450E00E9FAFA /* knnClassification.cpp */; };
BEA7B7141EDD7B6B0003E84B /* modelSet.cpp in Sources */ = {isa = PBXBuildFile; fileRef = BE2C5DF51ED8450E00E9FAFA /* modelSet.cpp */; };
......@@ -241,8 +244,6 @@
BE2C5EE11ED8480D00E9FAFA /* test_rapidPiPoTools.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = test_rapidPiPoTools.cpp; sourceTree = "<group>"; };
BE2C5EE21ED8480D00E9FAFA /* test_rapidXmmTools.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = test_rapidXmmTools.cpp; sourceTree = "<group>"; };
BE2C5EE31ED8480D00E9FAFA /* test_signalProcessing.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = test_signalProcessing.cpp; sourceTree = "<group>"; };
BE2C5EED1ED849AA00E9FAFA /* svm.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = svm.cpp; sourceTree = "<group>"; };
BE2C5EEE1ED849AA00E9FAFA /* svm.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = svm.h; sourceTree = "<group>"; };
BE2C5EF21EDD73D000E9FAFA /* mimo_stats.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = mimo_stats.h; sourceTree = "<group>"; };
BE2C5EF31EDD73D000E9FAFA /* PiPoBands.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = PiPoBands.h; sourceTree = "<group>"; };
BE2C5EF41EDD73D000E9FAFA /* PiPoBayesFilter.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = PiPoBayesFilter.h; sourceTree = "<group>"; };
......@@ -337,6 +338,12 @@
BE2C5F561EDD74FC00E9FAFA /* rta_configuration.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = rta_configuration.h; sourceTree = "<group>"; };
BE92864A1EF015E7006847CF /* rapidStream.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = rapidStream.cpp; sourceTree = "<group>"; };
BE92864B1EF015E7006847CF /* rapidStream.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = rapidStream.h; sourceTree = "<group>"; };
BE92864E1EF017E2006847CF /* svmClassification.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = svmClassification.cpp; sourceTree = "<group>"; };
BE92864F1EF017E2006847CF /* svmClassification.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = svmClassification.h; sourceTree = "<group>"; };
BE9286531EF01A23006847CF /* seriesClassification.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = seriesClassification.cpp; sourceTree = "<group>"; };
BE9286541EF01A23006847CF /* seriesClassification.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = seriesClassification.h; sourceTree = "<group>"; };
BE9286571EF01C45006847CF /* dtw.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = dtw.cpp; sourceTree = "<group>"; };
BE9286581EF01C45006847CF /* dtw.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = dtw.h; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
......@@ -677,11 +684,15 @@
BE2C5DE81ED8450E00E9FAFA /* src */ = {
isa = PBXGroup;
children = (
BE9286571EF01C45006847CF /* dtw.cpp */,
BE9286581EF01C45006847CF /* dtw.h */,
BE9286531EF01A23006847CF /* seriesClassification.cpp */,
BE9286541EF01A23006847CF /* seriesClassification.h */,
BE92864E1EF017E2006847CF /* svmClassification.cpp */,
BE92864F1EF017E2006847CF /* svmClassification.h */,
BE92864A1EF015E7006847CF /* rapidStream.cpp */,
BE92864B1EF015E7006847CF /* rapidStream.h */,
BE2C5DE91ED8450E00E9FAFA /* baseModel.h */,
BE2C5EED1ED849AA00E9FAFA /* svm.cpp */,
BE2C5EEE1ED849AA00E9FAFA /* svm.h */,
BE2C5DEA1ED8450E00E9FAFA /* classification.cpp */,
BE2C5DEB1ED8450E00E9FAFA /* classification.h */,
BE2C5DF31ED8450E00E9FAFA /* knnClassification.cpp */,
......@@ -1128,14 +1139,17 @@
files = (
BEA7B7181EDD7C0E0003E84B /* test_RapidLib.cpp in Sources */,
BEA7B7161EDD7B700003E84B /* regression.cpp in Sources */,
BE92865A1EF01C4A006847CF /* dtw.cpp in Sources */,
BEA7B7151EDD7B6E0003E84B /* neuralNetwork.cpp in Sources */,
BEA7B7141EDD7B6B0003E84B /* modelSet.cpp in Sources */,
BE9286561EF01A2D006847CF /* seriesClassification.cpp in Sources */,
BEA7B7131EDD7B690003E84B /* knnClassification.cpp in Sources */,
BEA7B7121EDD7B660003E84B /* classification.cpp in Sources */,
BEA7B70F1EDD7B350003E84B /* machineLearning.cpp in Sources */,
BEA7B7101EDD7B390003E84B /* trainingData.cpp in Sources */,
BEA7B7111EDD7B640003E84B /* svm.cpp in Sources */,
BEA7B7171EDD7B7E0003E84B /* jsoncpp.cpp in Sources */,
BE9286511EF01801006847CF /* svmClassification.cpp in Sources */,
BE9286521EF01823006847CF /* libsvm.cpp in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment