Commit 9bd97964 authored by mzed's avatar mzed
Browse files

pushed dtw changes through JavaScript. Tested.

parent e3ec45c1
......@@ -213,7 +213,7 @@ int main(int argc, const char * argv[]) {
///////////////////////////////////////////////////////////Testing with labels
seriesClassification myDtwLabel;
seriesClassification myDTW;
std::vector<trainingSeries> seriesVector;
trainingSeries tempSeries;
......@@ -233,9 +233,10 @@ int main(int argc, const char * argv[]) {
tempSeries.label = "second series";
seriesVector.push_back(tempSeries);
myDtwLabel.trainLabel(seriesVector);
assert(myDtwLabel.runLabel(seriesOne) == "first series");
assert(myDtwLabel.runLabel(seriesTwo) == "second series");
myDTW.train(seriesVector);
assert(myDTW.run(seriesOne) == "first series");
assert(myDTW.run(seriesTwo) == "second series");
std::cout << myDTW.getCosts()[0] << std::endl;
return 0;
}
......@@ -437,15 +437,6 @@ Module.SeriesClassification = function () {
};
Module.SeriesClassification.prototype = {
/**
* Adds a series to the array examples
* @param {Object} newSeries - An array of arrays
* @returns {Number} - index of the example series that best matches the input
*/
// addSeries: function (newSeries) {
// newSeries = Module.checkOutput(newSeries);
// return this.seriesClassification.addTrainingSet(Module.prepTrainingSet(newSeries));
// },
/**
* Resets the model, and adds a set of series to be evaluated
* @param {Object} newSeriesSet - an array of objects, each with input: <array of arrays> and label: <string>
......@@ -453,11 +444,7 @@ Module.SeriesClassification.prototype = {
*/
train: function (newSeriesSet) {
this.reset();
this.seriesClassification.trainLabel(Module.prepTrainingSeriesSet(newSeriesSet));
// for (var i = 0; i < newSeriesSet.length; ++i) {
// newSeriesSet[i] = Module.checkOutput(newSeriesSet[i]);
// this.seriesClassification.addTrainingSet(Module.prepTrainingSet(newSeriesSet[i]));
// }
this.seriesClassification.train(Module.prepTrainingSeriesSet(newSeriesSet));
return true;
},
/**
......@@ -481,7 +468,7 @@ Module.SeriesClassification.prototype = {
}
vecInputSeries.push_back(tempVector);
}
return this.seriesClassification.runLabel(vecInputSeries);
return this.seriesClassification.run(vecInputSeries);
},
/**
......@@ -494,14 +481,9 @@ Module.SeriesClassification.prototype = {
},
/**
* Returns an array of costs to match the input series to each example series. A lower cost is a closer match
* @param {Array} [inputSeries] - An array of arrays to be evaluated. (Optional)
* @returns {Array}
*/
getCosts: function (inputSeries) {
if (inputSeries) {
inputSeries = Module.checkOutput(inputSeries);
this.seriesClassification.runTrainingSet(Module.prepTrainingSet(inputSeries));
}
getCosts: function () {
let returnArray = [];
let VecDouble = this.seriesClassification.getCosts();
for (let i = 0; i < VecDouble.size(); ++i) {
......
......@@ -8,12 +8,9 @@ using namespace emscripten;
EMSCRIPTEN_BINDINGS(seriesClassification_module) {
class_<seriesClassification>("SeriesClassificationCpp") //name change so that I can wrap it in Javascript. -mz
.constructor()
//.function("addTrainingSet", &seriesClassification::addTrainingSet)
//.function("train", &seriesClassification::train)
.function("reset", &seriesClassification::reset)
.function("trainLabel", &seriesClassification::trainLabel)
.function("runLabel", &seriesClassification::runLabel)
//.function("runTrainingSet", &seriesClassification::runTrainingSet)
.function("train", &seriesClassification::train)
.function("run", &seriesClassification::run)
.function("getCosts", select_overload<std::vector<double>()>(&seriesClassification::getCosts))
;
......
......@@ -15,7 +15,7 @@ seriesClassification::seriesClassification() {};
seriesClassification::~seriesClassification() {};
bool seriesClassification::trainLabel(const std::vector<trainingSeries> &seriesSet) {
bool seriesClassification::train(const std::vector<trainingSeries> &seriesSet) {
reset();
bool trained = true;
allTrainingSeries = seriesSet;
......@@ -30,7 +30,7 @@ void seriesClassification::reset() {
allTrainingSeries.clear();
}
std::string seriesClassification::runLabel(const std::vector<std::vector<double>> &inputSeries) {
std::string seriesClassification::run(const std::vector<std::vector<double>> &inputSeries) {
dtw dtw;
int closestSeries = 0;
allCosts.clear();
......@@ -49,9 +49,10 @@ std::string seriesClassification::runLabel(const std::vector<std::vector<double>
};
//std::vector<double> seriesClassification::getCosts() {
// return allCosts;
//}
std::vector<double> seriesClassification::getCosts() {
return allCosts;
}
//
//std::vector<double> seriesClassification::getCosts(const std::vector<trainingExample> &trainingSet) {
// run(trainingSet);
......
......@@ -20,12 +20,10 @@ public:
seriesClassification();
~seriesClassification();
bool trainLabel(const std::vector<trainingSeries> &seriesSet);
bool train(const std::vector<trainingSeries> &seriesSet);
void reset();
std::string runLabel(const std::vector<std::vector<double>> &inputSeries);
std::string run(const std::vector<std::vector<double>> &inputSeries);
std::vector<double> getCosts();
private:
std::vector<trainingSeries> allTrainingSeries;
......
......@@ -339,7 +339,7 @@ describe('RapidLib Machine Learning', function () {
});
it('should report costs', function () {
expect(myDTW.getCosts()[0]).to.equal(17.325403217417502);
expect(myDTW.getCosts()[0]).to.equal(14.621232784634294);
expect(myDTW.getCosts()[1]).to.equal(0);
});
it('should report new costs');
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment