Commit c2a38f4d authored by mzed's avatar mzed
Browse files

template for seriesClassification

parent 53d20fc9
......@@ -216,7 +216,7 @@ int main(int argc, const char * argv[]) {
//Testing with labels
seriesClassification myDTW;
seriesClassification<double> myDTW;
std::vector<trainingSeries<double> > seriesVector;
trainingSeries<double> tempSeries;
......
......@@ -15,11 +15,14 @@
#define SEARCH_RADIUS 1
seriesClassification::seriesClassification() {};
template<typename T>
seriesClassification<T>::seriesClassification() {};
seriesClassification::~seriesClassification() {};
template<typename T>
seriesClassification<T>::~seriesClassification() {};
bool seriesClassification::train(const std::vector<trainingSeries<double> > &seriesSet) {
template<typename T>
bool seriesClassification<T>::train(const std::vector<trainingSeries<T> > &seriesSet) {
assert(seriesSet.size() > 0);
reset();
bool trained = true;
......@@ -35,7 +38,7 @@ bool seriesClassification::train(const std::vector<trainingSeries<double> > &ser
maxLength = newLength;
}
//Per Label
std::map<std::string, minMax<int> >::iterator it = lengthsPerLabel.find(allTrainingSeries[i].label);
typename std::map<std::string, minMax<int> >::iterator it = lengthsPerLabel.find(allTrainingSeries[i].label);
if (it != lengthsPerLabel.end()) {
int newLength = int(allTrainingSeries[i].input.size());
if (newLength < it->second.min) {
......@@ -53,7 +56,8 @@ bool seriesClassification::train(const std::vector<trainingSeries<double> > &ser
return trained;
};
void seriesClassification::reset() {
template<typename T>
void seriesClassification<T>::reset() {
allCosts.clear();
allTrainingSeries.clear();
lengthsPerLabel.clear();
......@@ -61,14 +65,15 @@ void seriesClassification::reset() {
maxLength = -1;
}
std::string seriesClassification::run(const std::vector<std::vector<double>> &inputSeries) {
template<typename T>
std::string seriesClassification<T>::run(const std::vector<std::vector<T>> &inputSeries) {
int closestSeries = 0;
allCosts.clear();
double lowestCost = fastDTW<double>::getCost(inputSeries, allTrainingSeries[0].input, SEARCH_RADIUS);
T lowestCost = fastDTW<T>::getCost(inputSeries, allTrainingSeries[0].input, SEARCH_RADIUS);
allCosts.push_back(lowestCost);
for (int i = 1; i < allTrainingSeries.size(); ++i) {
double currentCost = fastDTW<double>::getCost(inputSeries, allTrainingSeries[i].input, SEARCH_RADIUS);
T currentCost = fastDTW<T>::getCost(inputSeries, allTrainingSeries[i].input, SEARCH_RADIUS);
allCosts.push_back(currentCost);
if (currentCost < lowestCost) {
lowestCost = currentCost;
......@@ -79,40 +84,46 @@ std::string seriesClassification::run(const std::vector<std::vector<double>> &in
};
std::vector<double> seriesClassification::getCosts() {
template<typename T>
std::vector<T> seriesClassification<T>::getCosts() {
return allCosts;
}
int seriesClassification::getMinLength() {
template<typename T>
int seriesClassification<T>::getMinLength() {
return minLength;
}
int seriesClassification::getMinLength(std::string label) {
template<typename T>
int seriesClassification<T>::getMinLength(std::string label) {
int labelMinLength = -1;
std::map<std::string, minMax<int> >::iterator it = lengthsPerLabel.find(label);
typename std::map<std::string, minMax<int> >::iterator it = lengthsPerLabel.find(label);
if (it != lengthsPerLabel.end()) {
labelMinLength = it->second.min;
}
return labelMinLength;
}
int seriesClassification::getMaxLength() {
template<typename T>
int seriesClassification<T>::getMaxLength() {
return maxLength;
}
int seriesClassification::getMaxLength(std::string label) {
template<typename T>
int seriesClassification<T>::getMaxLength(std::string label) {
int labelMaxLength = -1;
std::map<std::string, minMax<int> >::iterator it = lengthsPerLabel.find(label);
typename std::map<std::string, minMax<int> >::iterator it = lengthsPerLabel.find(label);
if (it != lengthsPerLabel.end()) {
labelMaxLength = it->second.max;
}
return labelMaxLength;
}
seriesClassification::minMax<double> seriesClassification::calculateCosts(std::string label) {
minMax<double> calculatedMinMax;
calculatedMinMax.min = std::numeric_limits<double>::max();
calculatedMinMax.max = std::numeric_limits<double>::min();
template<typename T>
seriesClassification<T>::minMax<T> seriesClassification<T>::calculateCosts(std::string label) {
minMax<T> calculatedMinMax;
calculatedMinMax.min = std::numeric_limits<T>::max();
calculatedMinMax.max = std::numeric_limits<T>::min();
int numSeries = 0;
for (int i = 0; i < (allTrainingSeries.size() - 1); ++i) { //these loops are a little different than the two-label case
......@@ -120,7 +131,7 @@ seriesClassification::minMax<double> seriesClassification::calculateCosts(std::s
for (int j = (i + 1); j < allTrainingSeries.size(); ++j) {
if (allTrainingSeries[j].label == label) {
numSeries++;
double currentCost = fastDTW<double>::getCost(allTrainingSeries[i].input, allTrainingSeries[j].input, SEARCH_RADIUS);
T currentCost = fastDTW<T>::getCost(allTrainingSeries[i].input, allTrainingSeries[j].input, SEARCH_RADIUS);
if (numSeries == 1) {
calculatedMinMax.min = calculatedMinMax.max = currentCost; //first match is both min and max
} else {
......@@ -141,10 +152,11 @@ seriesClassification::minMax<double> seriesClassification::calculateCosts(std::s
return calculatedMinMax;
}
seriesClassification::minMax<double> seriesClassification::calculateCosts(std::string label1, std::string label2) {
minMax<double> calculatedMinMax;
calculatedMinMax.min = std::numeric_limits<double>::max();
calculatedMinMax.max = std::numeric_limits<double>::min();
template<typename T>
seriesClassification<T>::minMax<T> seriesClassification<T>::calculateCosts(std::string label1, std::string label2) {
minMax<T> calculatedMinMax;
calculatedMinMax.min = std::numeric_limits<T>::max();
calculatedMinMax.max = std::numeric_limits<T>::min();
int numSeries = 0;
for (int i = 0; i < (allTrainingSeries.size()); ++i) {
......@@ -152,7 +164,7 @@ seriesClassification::minMax<double> seriesClassification::calculateCosts(std::s
for (int j = 0; j < allTrainingSeries.size(); ++j) {
if (allTrainingSeries[j].label == label2) {
numSeries++;
double currentCost = fastDTW<double>::getCost(allTrainingSeries[i].input, allTrainingSeries[j].input, SEARCH_RADIUS);
T currentCost = fastDTW<T>::getCost(allTrainingSeries[i].input, allTrainingSeries[j].input, SEARCH_RADIUS);
if (numSeries == 1) {
calculatedMinMax.min = calculatedMinMax.max = currentCost; //first match is both min and max
} else {
......@@ -170,9 +182,13 @@ seriesClassification::minMax<double> seriesClassification::calculateCosts(std::s
return calculatedMinMax;
}
//explicit instantiation
template class seriesClassification<double>;
template class seriesClassification<float>;
//
//std::vector<double> seriesClassification::getCosts(const std::vector<trainingExample> &trainingSet) {
//std::vector<T> seriesClassification::getCosts(const std::vector<trainingExample> &trainingSet) {
// run(trainingSet);
// return allCosts;
//}
\ No newline at end of file
......@@ -15,34 +15,35 @@
#include "fastDTW.h"
#include "trainingExample.h"
template<typename T>
class seriesClassification {
public:
seriesClassification();
~seriesClassification();
bool train(const std::vector<trainingSeries<double> > &seriesSet);
bool train(const std::vector<trainingSeries<T> > &seriesSet);
void reset();
std::string run(const std::vector<std::vector<double>> &inputSeries);
std::vector<double> getCosts();
std::string run(const std::vector<std::vector<T>> &inputSeries);
std::vector<T> getCosts();
int getMinLength();
int getMinLength(std::string label);
int getMaxLength();
int getMaxLength(std::string label);
template<typename T>
template<typename TT>
struct minMax {
T min;
T max;
TT min;
TT max;
};
minMax<double> calculateCosts(std::string label);
minMax<double> calculateCosts(std::string label1, std::string label2);
minMax<T> calculateCosts(std::string label);
minMax<T> calculateCosts(std::string label1, std::string label2);
private:
std::vector<trainingSeries<double > > allTrainingSeries;
std::vector<double> allCosts;
std::vector<trainingSeries<T > > allTrainingSeries;
std::vector<T> allCosts;
int maxLength;
int minLength;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment