Commit 9cac0c29 authored by mzed's avatar mzed
Browse files

finishing hiding templates with typedef

parent 13e9722d
...@@ -212,9 +212,9 @@ int main(int argc, const char * argv[]) { ...@@ -212,9 +212,9 @@ int main(int argc, const char * argv[]) {
//Testing with labels //Testing with labels
seriesClassification<double> myDTW; seriesClassification myDTW;
std::vector<trainingSeries<double> > seriesVector; std::vector<trainingSeries> seriesVector;
trainingSeries<double> tempSeries; trainingSeries tempSeries;
tempSeries.input.push_back( { 1., 5.} ); tempSeries.input.push_back( { 1., 5.} );
tempSeries.input.push_back( { 2., 4.} ); tempSeries.input.push_back( { 2., 4.} );
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
#define SEARCH_RADIUS 1 #define SEARCH_RADIUS 1
template<typename T> template<typename T>
seriesClassification<T>::seriesClassification() {}; seriesClassificationTemplate<T>::seriesClassificationTemplate() {};
template<typename T> template<typename T>
seriesClassification<T>::~seriesClassification() {}; seriesClassificationTemplate<T>::~seriesClassificationTemplate() {};
template<typename T> template<typename T>
bool seriesClassification<T>::train(const std::vector<trainingSeries<T> > &seriesSet) { bool seriesClassificationTemplate<T>::train(const std::vector<trainingSeriesTemplate<T> > &seriesSet) {
assert(seriesSet.size() > 0); assert(seriesSet.size() > 0);
reset(); reset();
bool trained = true; bool trained = true;
...@@ -58,7 +58,7 @@ bool seriesClassification<T>::train(const std::vector<trainingSeries<T> > &serie ...@@ -58,7 +58,7 @@ bool seriesClassification<T>::train(const std::vector<trainingSeries<T> > &serie
}; };
template<typename T> template<typename T>
void seriesClassification<T>::reset() { void seriesClassificationTemplate<T>::reset() {
allCosts.clear(); allCosts.clear();
allTrainingSeries.clear(); allTrainingSeries.clear();
lengthsPerLabel.clear(); lengthsPerLabel.clear();
...@@ -67,7 +67,7 @@ void seriesClassification<T>::reset() { ...@@ -67,7 +67,7 @@ void seriesClassification<T>::reset() {
} }
template<typename T> template<typename T>
std::string seriesClassification<T>::run(const std::vector<std::vector<T>> &inputSeries) { std::string seriesClassificationTemplate<T>::run(const std::vector<std::vector<T>> &inputSeries) {
int closestSeries = 0; int closestSeries = 0;
allCosts.clear(); allCosts.clear();
T lowestCost = fastDTW<T>::getCost(inputSeries, allTrainingSeries[0].input, SEARCH_RADIUS); T lowestCost = fastDTW<T>::getCost(inputSeries, allTrainingSeries[0].input, SEARCH_RADIUS);
...@@ -85,7 +85,7 @@ std::string seriesClassification<T>::run(const std::vector<std::vector<T>> &inpu ...@@ -85,7 +85,7 @@ std::string seriesClassification<T>::run(const std::vector<std::vector<T>> &inpu
}; };
template<typename T> template<typename T>
T seriesClassification<T>::run(const std::vector<std::vector<T>> &inputSeries, std::string label) { T seriesClassificationTemplate<T>::run(const std::vector<std::vector<T>> &inputSeries, std::string label) {
int closestSeries = 0; int closestSeries = 0;
allCosts.clear(); allCosts.clear();
T lowestCost = std::numeric_limits<T>::max(); T lowestCost = std::numeric_limits<T>::max();
...@@ -103,17 +103,17 @@ T seriesClassification<T>::run(const std::vector<std::vector<T>> &inputSeries, s ...@@ -103,17 +103,17 @@ T seriesClassification<T>::run(const std::vector<std::vector<T>> &inputSeries, s
}; };
template<typename T> template<typename T>
std::vector<T> seriesClassification<T>::getCosts() const{ std::vector<T> seriesClassificationTemplate<T>::getCosts() const{
return allCosts; return allCosts;
} }
template<typename T> template<typename T>
int seriesClassification<T>::getMinLength() const{ int seriesClassificationTemplate<T>::getMinLength() const{
return minLength; return minLength;
} }
template<typename T> template<typename T>
int seriesClassification<T>::getMinLength(std::string label) const { int seriesClassificationTemplate<T>::getMinLength(std::string label) const {
int labelMinLength = -1; int labelMinLength = -1;
typename std::map<std::string, minMax<int> >::const_iterator it = lengthsPerLabel.find(label); typename std::map<std::string, minMax<int> >::const_iterator it = lengthsPerLabel.find(label);
if (it != lengthsPerLabel.end()) { if (it != lengthsPerLabel.end()) {
...@@ -123,12 +123,12 @@ int seriesClassification<T>::getMinLength(std::string label) const { ...@@ -123,12 +123,12 @@ int seriesClassification<T>::getMinLength(std::string label) const {
} }
template<typename T> template<typename T>
int seriesClassification<T>::getMaxLength() const { int seriesClassificationTemplate<T>::getMaxLength() const {
return maxLength; return maxLength;
} }
template<typename T> template<typename T>
int seriesClassification<T>::getMaxLength(std::string label) const { int seriesClassificationTemplate<T>::getMaxLength(std::string label) const {
int labelMaxLength = -1; int labelMaxLength = -1;
typename std::map<std::string, minMax<int> >::const_iterator it = lengthsPerLabel.find(label); typename std::map<std::string, minMax<int> >::const_iterator it = lengthsPerLabel.find(label);
if (it != lengthsPerLabel.end()) { if (it != lengthsPerLabel.end()) {
...@@ -138,7 +138,7 @@ int seriesClassification<T>::getMaxLength(std::string label) const { ...@@ -138,7 +138,7 @@ int seriesClassification<T>::getMaxLength(std::string label) const {
} }
template<typename T> template<typename T>
seriesClassification<T>::minMax<T> seriesClassification<T>::calculateCosts(std::string label) const { seriesClassificationTemplate<T>::minMax<T> seriesClassificationTemplate<T>::calculateCosts(std::string label) const {
minMax<T> calculatedMinMax; minMax<T> calculatedMinMax;
bool foundSeries = false; bool foundSeries = false;
std::vector<T> labelCosts; std::vector<T> labelCosts;
...@@ -163,7 +163,7 @@ seriesClassification<T>::minMax<T> seriesClassification<T>::calculateCosts(std:: ...@@ -163,7 +163,7 @@ seriesClassification<T>::minMax<T> seriesClassification<T>::calculateCosts(std::
} }
template<typename T> template<typename T>
seriesClassification<T>::minMax<T> seriesClassification<T>::calculateCosts(std::string label1, std::string label2) const { seriesClassificationTemplate<T>::minMax<T> seriesClassificationTemplate<T>::calculateCosts(std::string label1, std::string label2) const {
minMax<T> calculatedMinMax; minMax<T> calculatedMinMax;
bool foundSeries = false; bool foundSeries = false;
std::vector<T> labelCosts; std::vector<T> labelCosts;
...@@ -188,8 +188,8 @@ seriesClassification<T>::minMax<T> seriesClassification<T>::calculateCosts(std:: ...@@ -188,8 +188,8 @@ seriesClassification<T>::minMax<T> seriesClassification<T>::calculateCosts(std::
} }
//explicit instantiation //explicit instantiation
template class seriesClassification<double>; template class seriesClassificationTemplate<double>;
template class seriesClassification<float>; template class seriesClassificationTemplate<float>;
// //
......
...@@ -22,17 +22,17 @@ ...@@ -22,17 +22,17 @@
*/ */
template<typename T> template<typename T>
class seriesClassification { class seriesClassificationTemplate {
public: public:
/** Constructor, no params */ /** Constructor, no params */
seriesClassification(); seriesClassificationTemplate();
~seriesClassification(); ~seriesClassificationTemplate();
/** Train on a specified set of trainingSeries /** Train on a specified set of trainingSeries
* @param std::vector<trainingSeries> A vector of training series * @param std::vector<trainingSeries> A vector of training series
*/ */
bool train(const std::vector<trainingSeries<T> > &seriesSet); bool train(const std::vector<trainingSeriesTemplate<T> > &seriesSet);
/** Reset model to its initial state, forget all costs and training data*/ /** Reset model to its initial state, forget all costs and training data*/
void reset(); void reset();
...@@ -98,11 +98,14 @@ public: ...@@ -98,11 +98,14 @@ public:
minMax<T> calculateCosts(std::string label1, std::string label2) const; minMax<T> calculateCosts(std::string label1, std::string label2) const;
private: private:
std::vector<trainingSeries<T > > allTrainingSeries; std::vector<trainingSeriesTemplate<T> > allTrainingSeries;
std::vector<T> allCosts; std::vector<T> allCosts;
int maxLength; int maxLength;
int minLength; int minLength;
std::map<std::string, minMax<int> > lengthsPerLabel; std::map<std::string, minMax<int> > lengthsPerLabel;
}; };
typedef seriesClassificationTemplate<double> seriesClassification; //This is here to keep the old API working
typedef seriesClassificationTemplate<float> seriesClassificationFloat;
#endif #endif
...@@ -25,9 +25,12 @@ typedef trainingExampleTemplate<float> trainingExampleFloat; ...@@ -25,9 +25,12 @@ typedef trainingExampleTemplate<float> trainingExampleFloat;
/** This is used by DTW models for training */ /** This is used by DTW models for training */
template<typename T> template<typename T>
struct trainingSeries { struct trainingSeriesTemplate {
std::vector<std::vector<T> > input; std::vector<std::vector<T> > input;
std::string label; std::string label;
}; };
typedef trainingSeriesTemplate<double> trainingSeries;
typedef trainingSeriesTemplate<float> trainingSeriesFloat;
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment