/*-------------------------------------------------------------------------------
 This file is part of unityForest.

 Copyright (c) [2014-2018] [Marvin N. Wright]
 Modifications and extensions by Roman Hornung

 This software may be modified and distributed under the terms of the MIT license.

 Please note that the C++ core of divfor is distributed under MIT license and the
 R package "unityForest" under GPL3 license.
 #-------------------------------------------------------------------------------*/

#include <Rcpp.h>
#include <unordered_map>
#include <algorithm>
#include <iterator>
#include <random>
#include <stdexcept>
#include <cmath>
#include <string>

#include "utility.h"
#include "ForestClassification.h"
#include "TreeClassification.h"
#include "Data.h"

namespace unityForest
{

  /**
   * Restore a previously grown classification forest from its serialized parts.
   *
   * @param dependent_varID      Column index of the response variable in data.
   * @param num_trees            Number of trees to rebuild; every per-tree vector
   *                             below is indexed 0..num_trees-1.
   * @param forest_child_nodeIDs Per-tree child node IDs.
   * @param forest_split_varIDs  Per-tree split variable IDs.
   * @param forest_split_values  Per-tree split values.
   * @param class_values         Class labels; copied into this->class_values.
   * @param is_ordered_variable  Per-variable ordered flags, forwarded to data.
   */
  void ForestClassification::loadForest(size_t dependent_varID, size_t num_trees,
                                        std::vector<std::vector<std::vector<size_t>>> &forest_child_nodeIDs,
                                        std::vector<std::vector<size_t>> &forest_split_varIDs, std::vector<std::vector<double>> &forest_split_values,
                                        std::vector<double> &class_values, std::vector<bool> &is_ordered_variable)
  {
    this->dependent_varID = dependent_varID;
    this->num_trees = num_trees;
    this->class_values = class_values;
    data->setIsOrderedVariable(is_ordered_variable);

    // Rebuild every tree. Each one is handed pointers to this forest's own
    // class_values and response_classIDs members (not to the arguments).
    trees.reserve(num_trees);
    for (size_t tree_idx = 0; tree_idx != num_trees; ++tree_idx)
    {
      trees.emplace_back(std::make_unique<TreeClassification>(forest_child_nodeIDs[tree_idx],
                                                              forest_split_varIDs[tree_idx],
                                                              forest_split_values[tree_idx],
                                                              &this->class_values,
                                                              &response_classIDs));
    }

    // Distribute tree indices [0, num_trees - 1] over the worker threads.
    equalSplit(thread_ranges, 0, num_trees - 1, num_threads);
  }

  /**
   * Restore a saved classification forest for the CRTR analysis.
   *
   * In addition to the tree structure, this restores the class weights and the
   * per-tree root-node / in-bag information. It also rebuilds response_classIDs
   * from the current data in the same way initInternal() does.
   *
   * @param dependent_varID       Column index of the response variable in data.
   * @param num_trees             Number of trees to rebuild; every per-tree vector
   *                              below is indexed 0..num_trees-1.
   * @param forest_child_nodeIDs  Per-tree child node IDs.
   * @param forest_split_varIDs   Per-tree split variable IDs.
   * @param forest_split_values   Per-tree split values.
   * @param class_values          Saved class labels; copied into this->class_values.
   *                              The argument itself is not modified.
   * @param class_weights         Saved class weights; copied into this->class_weights.
   * @param forest_nodeID_in_root Per-tree node IDs in root, passed on to each tree.
   * @param forest_inbag_counts   Per-tree in-bag counts, passed on to each tree.
   * @param is_ordered_variable   Per-variable ordered flags, forwarded to data.
   */
  void ForestClassification::loadForestRepr(size_t dependent_varID, size_t num_trees,
                                            std::vector<std::vector<std::vector<size_t>>> &forest_child_nodeIDs,
                                            std::vector<std::vector<size_t>> &forest_split_varIDs, std::vector<std::vector<double>> &forest_split_values,
                                            std::vector<double> &class_values, std::vector<double> &class_weights, std::vector<std::vector<size_t>> &forest_nodeID_in_root,
                                            std::vector<std::vector<size_t>> &forest_inbag_counts,
                                            std::vector<bool> &is_ordered_variable)
  {

    this->dependent_varID = dependent_varID;
    this->num_trees = num_trees;
    this->class_values = class_values;
    this->class_weights = class_weights;
    data->setIsOrderedVariable(is_ordered_variable);

    /* build response_classIDs exactly like initInternal ------------------- */
    // This must look up and extend the MEMBER this->class_values, because that
    // is what the trees point to. The same-named parameter must not be used.
    // Otherwise a label unseen at save time would be appended to the caller's
    // vector, and its ID would index past the end of the member vector.
    // NOTE(review): this->class_weights is not extended for such new labels;
    // confirm whether TreeClassification indexes class_weights by classID.
    const size_t num_rows = data->getNumRows();
    response_classIDs.clear();
    response_classIDs.reserve(num_rows);

    for (size_t row = 0; row < num_rows; ++row)
    {
      const double y = data->get(row, dependent_varID);

      // find() returns an iterator; subtracting begin() gives the index.
      uint classID = std::find(this->class_values.begin(),
                               this->class_values.end(),
                               y) -
                     this->class_values.begin();

      // If not found, append; classID then equals the new last index.
      if (classID == this->class_values.size())
      {
        this->class_values.push_back(y);
      }

      response_classIDs.push_back(classID);
    }
    /* -------------------------------------------------------------------- */

    // Create trees; each receives pointers to this forest's members.
    trees.reserve(num_trees);
    for (size_t i = 0; i < num_trees; ++i)
    {
      trees.push_back(
          std::make_unique<TreeClassification>(forest_child_nodeIDs[i], forest_split_varIDs[i], forest_split_values[i],
                                               &this->class_values, &this->class_weights, &response_classIDs, forest_nodeID_in_root[i], forest_inbag_counts[i], this->repr_vars, data.get()));
    }

    // Distribute tree indices [0, num_trees - 1] over the worker threads.
    equalSplit(thread_ranges, 0, num_trees - 1, num_threads);
  }

  /**
   * Prepare the forest for growing or prediction.
   *
   * - Fills in defaults for mtry and min_node_size when they are 0.
   * - In training mode, builds class_values and response_classIDs from the
   *   response column.
   * - Groups sample IDs per class when class-wise sample fractions are given.
   * - Resets all class weights to 1.
   * - Sorts the data unless memory_saving_splitting is set.
   *
   * @param status_variable_name Not used by this classification implementation.
   * @throws std::runtime_error if class-wise sample fractions are given but a
   *         response class ID has no corresponding sample_fraction entry.
   */
  void ForestClassification::initInternal(std::string status_variable_name)
  {

    // If mtry not set, use the floored square root of the number of independent
    // variables. The "- 1" presumably excludes the response column. Guard against
    // unsigned wrap-around when num_variables is 0.
    if (mtry == 0)
    {
      const size_t num_independent = (num_variables > 0) ? num_variables - 1 : 0;
      const unsigned long temp = static_cast<unsigned long>(std::sqrt(static_cast<double>(num_independent)));
      mtry = std::max(static_cast<unsigned long>(1), temp);
    }

    // Set minimal node size
    if (min_node_size == 0)
    {
      min_node_size = DEFAULT_MIN_NODE_SIZE_CLASSIFICATION;
    }

    // Create class_values and response_classIDs
    if (!prediction_mode)
    {
      response_classIDs.reserve(response_classIDs.size() + num_samples);
      for (size_t i = 0; i < num_samples; ++i)
      {
        const double value = data->get(i, dependent_varID);

        // If value is already in class_values, use its index. Else register a new class.
        uint classID = std::find(class_values.begin(), class_values.end(), value) - class_values.begin();
        if (classID == class_values.size())
        {
          class_values.push_back(value);
        }
        response_classIDs.push_back(classID);
      }
    }

    // Create sampleIDs_per_class if class-wise sample fractions are used
    if (sample_fraction.size() > 1)
    {
      sampleIDs_per_class.resize(sample_fraction.size());
      for (auto &v : sampleIDs_per_class)
      {
        v.reserve(num_samples);
      }
      for (size_t i = 0; i < num_samples; ++i)
      {
        const size_t classID = response_classIDs[i];
        // Without this check a class beyond the provided fractions would write
        // out of bounds of sampleIDs_per_class.
        if (classID >= sampleIDs_per_class.size())
        {
          throw std::runtime_error("Number of sample fractions (" + std::to_string(sample_fraction.size()) +
                                   ") is smaller than the number of classes in the response.");
        }
        sampleIDs_per_class[classID].push_back(i);
      }
    }

    // Set class weights all to 1
    class_weights = std::vector<double>(class_values.size(), 1.0);

    // Pre-sort the data unless memory-saving splitting is enabled
    if (!memory_saving_splitting)
    {
      data->sort();
    }
  }

  /**
   * Create num_trees fresh classification trees, ready to be grown.
   *
   * Every tree is given pointers to the forest's class_values,
   * response_classIDs, sampleIDs_per_class and class_weights. It also gets a
   * pointer to the pre-computed allowedVarIDs_ list.
   */
  void ForestClassification::growInternal()
  {
    trees.reserve(num_trees);
    for (size_t remaining = num_trees; remaining > 0; --remaining)
    {
      trees.emplace_back(std::make_unique<TreeClassification>(&class_values, &response_classIDs,
                                                              &sampleIDs_per_class, &class_weights));
      // Share the pre-computed allowed-variable list with the tree just added.
      trees.back()->setAllowedVarIDs(&allowedVarIDs_);
    }
  }

  /**
   * Allocate a zero-initialised predictions buffer for the current data.
   *
   * Layout:
   * - predictions[0][sample][tree] when per-tree output is requested
   *   (predict_all or prediction_type == TERMINALNODES);
   * - predictions[0][0][sample] otherwise.
   */
  void ForestClassification::allocatePredictMemory()
  {
    const size_t n_rows = data->getNumRows();
    const bool per_tree_output = predict_all || prediction_type == TERMINALNODES;

    predictions.assign(1, std::vector<std::vector<double>>());
    std::vector<std::vector<double>> &layer = predictions[0];
    if (per_tree_output)
    {
      layer.assign(n_rows, std::vector<double>(num_trees));
    }
    else
    {
      layer.assign(1, std::vector<double>(n_rows));
    }
  }

  /**
   * Compute the prediction(s) for a single sample.
   *
   * With predict_all or TERMINALNODES, each tree's output is written to
   * predictions[0][sample_idx][tree]. That output is the terminal node ID for
   * TERMINALNODES and the predicted class value otherwise.
   *
   * Otherwise the trees' class predictions are tallied, and the value returned
   * by mostFrequentValue() (using random_number_generator) is written to
   * predictions[0][0][sample_idx].
   *
   * @param sample_idx Index of the sample to predict (forwarded to every tree).
   */
  void ForestClassification::predictInternal(size_t sample_idx)
  {
    const bool per_tree_output = predict_all || prediction_type == TERMINALNODES;

    if (!per_tree_output)
    {
      // Majority vote across all trees.
      std::unordered_map<double, size_t> votes;
      for (size_t t = 0; t < num_trees; ++t)
      {
        ++votes[getTreePrediction(t, sample_idx)];
      }
      predictions[0][0][sample_idx] = mostFrequentValue(votes, random_number_generator);
      return;
    }

    // Per-tree output: node IDs for TERMINALNODES, class values otherwise.
    const bool want_node_ids = (prediction_type == TERMINALNODES);
    std::vector<double> &per_tree = predictions[0][sample_idx];
    for (size_t t = 0; t < num_trees; ++t)
    {
      per_tree[t] = want_node_ids ? getTreePredictionTerminalNodeID(t, sample_idx)
                                  : getTreePrediction(t, sample_idx);
    }
  }

  /**
   * Compute out-of-bag (OOB) predictions and the OOB misclassification rate.
   *
   * Steps:
   * 1. For every sample, tally the predictions of the trees for which it was OOB.
   * 2. Store the value chosen by mostFrequentValue() (using
   *    random_number_generator) in predictions[0][0][sample]. A sample that was
   *    never OOB gets NAN.
   * 3. Set overall_prediction_error to the fraction of misclassified samples
   *    among those that have a prediction. It is NAN if no sample has one.
   * 4. Increment classification_table for each (true, predicted) pair.
   */
  void ForestClassification::computePredictionErrorInternal()
  {
    // Per-sample vote counts: class value -> number of trees predicting it.
    std::vector<std::unordered_map<double, size_t>> class_counts(num_samples);

    // For each tree, loop over its OOB samples and count the predicted classes.
    // A tree's predictions are indexed by position within its OOB sample list.
    for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx)
    {
      const auto &oob_sampleIDs = trees[tree_idx]->getOobSampleIDs();
      const size_t num_oob = trees[tree_idx]->getNumSamplesOob();
      for (size_t sample_idx = 0; sample_idx < num_oob; ++sample_idx)
      {
        const size_t sampleID = oob_sampleIDs[sample_idx];
        ++class_counts[sampleID][getTreePrediction(tree_idx, sample_idx)];
      }
    }

    // Majority vote for each sample (NAN if the sample was never OOB).
    predictions = std::vector<std::vector<std::vector<double>>>(1,
                                                                std::vector<std::vector<double>>(1, std::vector<double>(num_samples)));
    std::vector<double> &oob_predictions = predictions[0][0];
    for (size_t i = 0; i < num_samples; ++i)
    {
      if (!class_counts[i].empty())
      {
        oob_predictions[i] = mostFrequentValue(class_counts[i], random_number_generator);
      }
      else
      {
        oob_predictions[i] = NAN;
      }
    }

    // Compare predictions with the true response.
    size_t num_missclassifications = 0;
    size_t num_predictions = 0;
    for (size_t i = 0; i < oob_predictions.size(); ++i)
    {
      const double predicted_value = oob_predictions[i];
      if (std::isnan(predicted_value))
      {
        continue; // no OOB prediction for this sample
      }
      ++num_predictions;
      const double real_value = data->get(i, dependent_varID);
      if (predicted_value != real_value)
      {
        ++num_missclassifications;
      }
      ++classification_table[std::make_pair(real_value, predicted_value)];
    }

    // Handle "no OOB sample at all" explicitly instead of relying on 0.0 / 0.0.
    overall_prediction_error = (num_predictions > 0)
                                   ? static_cast<double>(num_missclassifications) / static_cast<double>(num_predictions)
                                   : NAN;
  }

  // #nocov start
  double ForestClassification::getTreePrediction(size_t tree_idx, size_t sample_idx) const
  {
    const auto &tree = dynamic_cast<const TreeClassification &>(*trees[tree_idx]);
    return tree.getPrediction(sample_idx);
  }

  size_t ForestClassification::getTreePredictionTerminalNodeID(size_t tree_idx, size_t sample_idx) const
  {
    const auto &tree = dynamic_cast<const TreeClassification &>(*trees[tree_idx]);
    return tree.getPredictionTerminalNodeID(sample_idx);
  }

  // #nocov end

} // namespace unityForest
