//
//  rapidGVF.cpp
//
//  Created by Francisco on 04/05/2017.
//  Copyright © 2017 Goldsmiths. All rights reserved.
//

#include "rapidGVF.h"
#include "../trainingData.h"

// Constructor: performs no explicit initialisation; every member is
// initialised as declared in the class definition.
rapidGVF::rapidGVF() = default;

// Destructor: no explicit cleanup; members are destroyed automatically.
rapidGVF::~rapidGVF() = default;

// Builds GVF gesture templates from recorded phrases.
//
// Each phrase in newTrainingData.trainingSet that contains at least one
// element becomes exactly one gesture template. The `input` vector of every
// element is converted from double to float and appended, in order, as an
// observation of that template. Phrases without elements are skipped.
//
// @param newTrainingData  recorded phrases; only each element's `input` is read.
// @return false (model untouched) when there is no phrase, or when no phrase
//         contains any element; true once the templates have been added.
bool rapidGVF::train(const rapidmix::trainingData &newTrainingData)
{
    // Reject sets that would not yield a single template. Checking every
    // phrase (not just the single-phrase case) keeps empty gestures from
    // ever reaching addGestureTemplate.
    bool hasObservations = false;
    for (const auto &phrase : newTrainingData.trainingSet)
    {
        if (!phrase.elements.empty())
        {
            hasObservations = true;
            break;
        }
    }
    if (!hasObservations)
    {
        // no recorded phrase, or only empty recorded phrases
        return false;
    }

    if (gvf.getState() != GVF::STATE_LEARNING)
    {
        gvf.setState(GVF::STATE_LEARNING);
    }

    // Go through every phrase; each non-empty phrase becomes one template.
    for (const auto &phrase : newTrainingData.trainingSet)
    {
        if (phrase.elements.empty())
        {
            continue;
        }

        gvf.startGesture();

        // currentGesture is reused across phrases and by run(). Without
        // clearing it, this template would also contain every observation
        // of the previous phrases (and any frames fed to run()).
        this->currentGesture.clear();

        for (const auto &element : phrase.elements)
        {
            // GVF works with float observations; narrow element-wise.
            std::vector<float> vf(element.input.begin(), element.input.end());
            this->currentGesture.addObservation(vf);
        }
        gvf.addGestureTemplate(this->currentGesture);
    }
    return true;
}

// Feeds one observation frame to the gesture follower and returns its estimate.
//
// @param inputVector  one observation frame; converted to float for GVF.
// @return an empty vector if inputVector is empty; otherwise
//         [likeliestGesture, likelihoods..., alignments...]. The same outcomes
//         are also cached and exposed through the getters below.
std::vector<double> rapidGVF::run(const std::vector<double> &inputVector)
{
    if (inputVector.empty())
    {
        return std::vector<double>();
    }

    // NOTE(review): restart() is called on every frame. If it reinitialises
    // the follower's internal state (as the name suggests), no tracking
    // carries over between calls — confirm this is intended.
    gvf.restart();

    if (gvf.getState() != GVF::STATE_FOLLOWING)
    {
        gvf.setState(GVF::STATE_FOLLOWING);
    }

    // GVF works with float observations; narrow element-wise.
    std::vector<float> vf(inputVector.begin(), inputVector.end());

    // Only the latest observation is used below. Clearing first stops
    // currentGesture from growing by one frame per call, and from mixing
    // with observations left over from training.
    this->currentGesture.clear();
    this->currentGesture.addObservation(vf);
    outcomes = gvf.update(this->currentGesture.getLastObservation());

    std::vector<double> output;
    output.reserve(1 + outcomes.likelihoods.size() + outcomes.alignments.size());
    output.push_back(outcomes.likeliestGesture);
    output.insert(output.end(), outcomes.likelihoods.begin(), outcomes.likelihoods.end());
    output.insert(output.end(), outcomes.alignments.begin(), outcomes.alignments.end());
    return output;
}

// Returns a copy of the likelihoods cached from the most recent run() call
// (presumably one entry per gesture template — confirm against GVF).
const std::vector<float> rapidGVF::getLikelihoods()
{
    return outcomes.likelihoods;
}

// Returns a copy of the alignments cached from the most recent run() call
// (presumably one entry per gesture template — confirm against GVF).
const std::vector<float> rapidGVF::getAlignments()
{
    return outcomes.alignments;
}

// Returns a non-owning pointer to the dynamics cached from the most recent
// run() call. The pointer stays valid for this object's lifetime, but the
// contents it points to are replaced on every subsequent run().
const std::vector<std::vector<float> > * rapidGVF::getDynamics()
{
    return &outcomes.dynamics;
}

// Returns a non-owning pointer to the scalings cached from the most recent
// run() call. The pointer stays valid for this object's lifetime, but the
// contents it points to are replaced on every subsequent run().
const std::vector<std::vector<float> > * rapidGVF::getScalings()
{
    return &outcomes.scalings;
}

// Returns a non-owning pointer to the rotations cached from the most recent
// run() call. The pointer stays valid for this object's lifetime, but the
// contents it points to are replaced on every subsequent run().
const std::vector<std::vector<float> > * rapidGVF::getRotations()
{
    return &outcomes.rotations;
}