/*
  Author: Xuye Luo
  Date: December 12, 2025
*/

#include <Rcpp.h>
#include <cmath>
#include <unordered_map>
using namespace Rcpp;

// [[Rcpp::interfaces(r, cpp)]]
/*
//' @title Fast Pearson's Chi-squared Statistic (C++ Backend)
//' @description Calculates the Chi-squared statistic directly from raw data vectors
//' using hash maps for efficient counting.
//' @param x Numeric vector of the first variable.
//' @param y Numeric vector of the second variable.
//' @return A List containing the statistic, the sample size (n), and the number of distinct values of x (nr) and of y (nc).
*/
// Pearson's chi-squared statistic for the contingency table of (x, y).
//
// The table is built with hash maps keyed on the raw double values, so no
// factor conversion is needed. Pairs in which x[i] or y[i] is NA/NaN are
// skipped, mirroring the complete.cases() filtering of R's chisq.test(x, y).
// NaN compares unequal to itself, so letting it into an unordered_map would
// create one key per NaN and break the marginal lookups (E = 0, Inf result).
//
// Returns a List with:
//   statistic - sum((O - E)^2 / E) over all nr x nc cells, as a double;
//               0 when either side has fewer than 2 distinct values
//   n         - number of complete (x, y) pairs used
//   nr, nc    - number of distinct x / y values among those pairs
// Calls Rcpp::stop() if x and y have different lengths.
// [[Rcpp::export]]
List chisq_cpp(const NumericVector &x, const NumericVector &y) {

  // Compare the full lengths; narrowing to int first could mask a mismatch.
  if (x.size() != y.size()) {
    stop("Lengths of 'x' and 'y' must match.");
  }
  const R_xlen_t len = x.size();

  // Contingency table: observed[x_value][y_value] = count,
  // plus row (x) and column (y) marginal counts.
  std::unordered_map<double, std::unordered_map<double, double>> observed;
  std::unordered_map<double, double> row_sum;
  std::unordered_map<double, double> col_sum;

  int n = 0;  // number of complete pairs actually tabulated
  for (R_xlen_t i = 0; i < len; ++i) {
    const double val_x = x[i];
    const double val_y = y[i];

    // NA_real_ is a NaN; drop incomplete pairs instead of hashing them.
    if (std::isnan(val_x) || std::isnan(val_y)) continue;

    observed[val_x][val_y] += 1.0;
    row_sum[val_x] += 1.0;
    col_sum[val_y] += 1.0;
    ++n;
  }

  const int nr = static_cast<int>(row_sum.size());
  const int nc = static_cast<int>(col_sum.size());

  // Degenerate table (also covers n == 0): the statistic is defined as 0.
  if (nr < 2 || nc < 2) {
    return List::create(Named("statistic") = 0.0,
                        Named("n")  = n,
                        Named("nr") = nr,
                        Named("nc") = nc);
  }

  // Sum((O - E)^2 / E) over ALL cells equals Sum(O^2 / E) - N, and empty
  // cells contribute 0 to Sum(O^2 / E), so only the observed (nonzero)
  // cells stored in the hash maps need to be visited.
  const double N_dbl = static_cast<double>(n);
  double sum_sq_O_over_E = 0.0;

  for (auto const& row : observed) {
    // .at() never inserts; every key here was added to the marginals above.
    const double r_sum = row_sum.at(row.first);

    for (auto const& cell : row.second) {
      const double O = cell.second;
      const double c_sum = col_sum.at(cell.first);
      const double E = (r_sum * c_sum) / N_dbl;  // expected count
      sum_sq_O_over_E += (O * O) / E;
    }
  }

  double statistic = sum_sq_O_over_E - N_dbl;

  // Floating-point cancellation can leave a tiny negative value for
  // (near-)independent data; the true statistic is never negative.
  if (statistic < 0.0) statistic = 0.0;

  return List::create(Named("statistic") = statistic,
                      Named("n")  = n,
                      Named("nr") = nr,
                      Named("nc") = nc);
}
