/*
 * machineLearning.cpp
 * Created by Michael Zbyszynski on 10 Jan 2016
 * Copyright © 2017 Goldsmiths. All rights reserved.
 */

#include "machineLearning.h"

#include <cstddef>
#include <utility>

RAPIDMIX_BEGIN_NAMESPACE

void trainingData2rapidLib (const trainingData &newTrainingData, std::vector<trainingExample> &trainingSet) {
    for (int h = 0; h < newTrainingData.trainingSet.size(); ++h) { //Go through every phrase
        
        for (int i = 0; i < newTrainingData.trainingSet[h].elements.size(); ++i) { //...and every element
            trainingExample tempExample;
            tempExample.input = newTrainingData.trainingSet[h].elements[i].input;
            if (newTrainingData.trainingSet[h].elements[i].output.size() > 0) {
                tempExample.output = newTrainingData.trainingSet[h].elements[i].output;
            } else {
                tempExample.output.push_back(double(h));
            }
            trainingSet.push_back(tempExample);
        }
    }
};

/**
 * Train the classifier from phrase-based training data.
 *
 * Each phrase's label is recorded in phrase order, so that run() can turn a
 * class index back into a label. The phrases are then flattened via
 * trainingData2rapidLib and passed to the underlying classification model.
 *
 * @param newTrainingData  phrases to learn from
 * @return whatever classification::train reports
 */
template<>
bool machineLearning<classification>::train(const trainingData &newTrainingData) {
    labels.clear();
    for (const auto &phrase : newTrainingData.trainingSet) {
        labels.push_back(phrase.label);
    }

    std::vector<trainingExample> examples;
    trainingData2rapidLib(newTrainingData, examples);
    return classification::train(examples);
}

/**
 * Train the regression model from phrase-based training data.
 *
 * No label bookkeeping is done here. The phrases are flattened via
 * trainingData2rapidLib and passed to the underlying regression model.
 *
 * @param newTrainingData  phrases to learn from
 * @return whatever regression::train reports
 */
template<>
bool machineLearning<regression>::train(const trainingData &newTrainingData) {
    std::vector<trainingExample> examples;
    trainingData2rapidLib(newTrainingData, examples);
    const bool trained = regression::train(examples);
    return trained;
}

/**
 * Train the series classifier from phrase-based training data.
 *
 * Each phrase becomes one trainingSeries. The series carries the phrase's
 * label and, in order, the input vector of every element in the phrase.
 * Element outputs are not used here.
 *
 * @param newTrainingData  phrases to learn from
 * @return whatever seriesClassification::trainLabel reports
 */
template<>
bool machineLearning<seriesClassification>::train(const trainingData &newTrainingData) {
    std::vector<trainingSeries> seriesSet;
    for (const auto &phrase : newTrainingData.trainingSet) {
        trainingSeries series;
        series.label = phrase.label;
        for (const auto &element : phrase.elements) {
            series.input.push_back(element.input);
        }
        seriesSet.push_back(series);
    }
    return seriesClassification::trainLabel(seriesSet);
}

/**
 * Classify one input vector and return the label of the matching phrase.
 *
 * The first value produced by classification::run is treated as an index
 * into `labels`. train() fills `labels` in phrase order, and
 * trainingData2rapidLib uses the phrase index as the output whenever an
 * element has no explicit output.
 *
 * @param inputVector  feature vector to classify
 * @param label        not used by this specialization
 * @return the matching label, or an empty string if the classifier produced no
 *         output or a value that is not a valid index into `labels`. This can
 *         happen, for example, when the model is untrained, or when training
 *         used explicit outputs that are not phrase indices.
 */
template<>
std::string machineLearning<classification>::run(const std::vector<double> &inputVector, const std::string &label) {
    const auto result = classification::run(inputVector);
    // Reading result[0] on an empty result is undefined behaviour.
    if (result.empty()) {
        return std::string();
    }
    const double rawIndex = result[0];
    // Written as a negated in-range test so that NaN is rejected as well.
    // Converting NaN or an out-of-range double to an integer is undefined
    // behaviour, as is indexing `labels` out of bounds.
    if (!(rawIndex >= 0.0 && rawIndex < static_cast<double>(labels.size()))) {
        return std::string();
    }
    return labels[static_cast<std::size_t>(rawIndex)];
}

/**
 * Classify a whole input series and return the label of the best match.
 *
 * This delegates directly to seriesClassification::runLabel.
 *
 * @param inputSeries  sequence of input vectors to classify
 * @return the label reported by seriesClassification::runLabel
 */
template<>
std::string machineLearning<seriesClassification>::run(const std::vector<std::vector<double> > &inputSeries) {
    std::string matchedLabel = seriesClassification::runLabel(inputSeries);
    return matchedLabel;
}


RAPIDMIX_END_NAMESPACE