#' MATES
#'
#' Graph-based two-sample testing that combines within-sample edge-weight
#' statistics (\eqn{U_x}, \eqn{U_y}) computed on several views of the pooled
#' data into a single quadratic-form test; see \code{\link{MATES}} and
#' \code{\link{MATES_test}}.
#'
#' @aliases MATES-package
"_PACKAGE"

#' Find the permutation mean
#'
#' Closed-form permutation mean of the within-sample statistics of every view.
#' The work is delegated to the compiled routine \code{asymp_mean_rcpp}.
#'
#' @param R_list A list of S numeric weight matrices, one per view; each is
#'   N x N with N = m + n
#' @param m An integer giving the number of samples in X
#' @param n An integer giving the number of samples in Y
#' @return A numeric vector of length 2*S; \code{MATES_test} subtracts it
#'   elementwise from the observed \code{UxUy} vector
#' @export
#' @examples
#' \donttest{
#' # Generate simulated data
#' set.seed(123)
#' X <- matrix(rnorm(20), ncol = 2)  # 10 samples, 2 dimensions
#' Y <- matrix(rnorm(20), ncol = 2)  # 10 samples, 2 dimensions
#' Z <- rbind(X, Y)
#' m <- nrow(X)
#' n <- nrow(Y)
#' N <- m + n
#'
#' # Compute distance and similarity matrices
#' D <- as.matrix(dist(Z, method = "manhattan"))
#' S <- max(D) - D
#'
#' # Compute rank matrix (simplified NNG approach)
#' R <- matrix(0, N, N)
#' k <- 3
#' for(i in 1:N) {
#'   neighbors <- order(D[i,])[2:(k+1)]  # k nearest neighbors
#'   R[i, neighbors] <- 1:k
#' }
#' R <- R + t(R)
#'
#' # Create list with one rank matrix
#' R_list <- list(R)
#'
#' # Calculate permutation mean
#' mean_vec <- asy_mean(R_list, m = m, n = n)
#' print(mean_vec)
#' }
#' @useDynLib MATES, .registration = TRUE
#' @importFrom Rcpp sourceCpp
asy_mean <- function(R_list, m, n) {
  # All computation happens in the C++ backend.
  perm_mean <- asymp_mean_rcpp(R_list, m, n)
  perm_mean
}

#' Find the permutation covariance
#'
#' Closed-form permutation covariance of the within-sample statistics of every
#' view. The work is delegated to the compiled routine \code{asymp_cov_rcpp}.
#'
#' @param R_list A list of S numeric weight matrices, one per view; each is
#'   N x N with N = m + n
#' @param m An integer giving the number of samples in X
#' @param n An integer giving the number of samples in Y
#' @return A numeric 2*S x 2*S matrix; \code{MATES_test} uses it as the
#'   covariance of the \code{UxUy} vector in its closed-form statistic
#' @export
#' @examples
#' \donttest{
#' # Generate simulated data
#' set.seed(123)
#' X <- matrix(rnorm(20), ncol = 2)  # 10 samples, 2 dimensions
#' Y <- matrix(rnorm(20), ncol = 2)  # 10 samples, 2 dimensions
#' Z <- rbind(X, Y)
#' m <- nrow(X)
#' n <- nrow(Y)
#' N <- m + n
#'
#' # Compute distance and similarity matrices
#' D <- as.matrix(dist(Z, method = "manhattan"))
#' S <- max(D) - D
#'
#' # Compute rank matrix (simplified NNG approach)
#' R <- matrix(0, N, N)
#' k <- 3
#' for(i in 1:N) {
#'   neighbors <- order(D[i,])[2:(k+1)]  # k nearest neighbors
#'   R[i, neighbors] <- 1:k
#' }
#' R <- R + t(R)
#'
#' # Create list with one rank matrix
#' R_list <- list(R)
#'
#' # Calculate permutation covariance
#' cov_mat <- asy_cov(R_list, m = m, n = n)
#' print(cov_mat)
#' }
#' @useDynLib MATES, .registration = TRUE
#' @importFrom Rcpp sourceCpp
asy_cov <- function(R_list, m, n) {
  # All computation happens in the C++ backend.
  perm_cov <- asymp_cov_rcpp(R_list, m, n)
  perm_cov
}

#' RISE rank matrix
#'
#' Convert a similarity matrix into a matrix of ranks. Two methods are
#' available:
#' \describe{
#'   \item{\code{"overall"}}{the upper-triangle (off-diagonal) similarities are
#'     ranked jointly, ties receiving average ranks, and mirrored into a
#'     symmetric matrix with a zero diagonal.}
#'   \item{\code{"row"}}{each row is ranked separately with its diagonal entry
#'     forced to be the smallest, then shifted down by one, so the diagonal is
#'     0, the least similar other point gets 1 and the most similar gets
#'     \code{ncol(S) - 1}. The result need not be symmetric.}
#' }
#'
#' @param S A numeric square matrix representing the similarity matrix
#' @param method A character string indicating the ranking method to use,
#'   \code{"overall"} (default) or \code{"row"}. Any other value is an error.
#' @return A numeric matrix of the same dimension as \code{S} holding the ranks
#' @useDynLib MATES, .registration = TRUE
#' @importFrom Rcpp sourceCpp
Rise_Rank <- function(S, method = c("overall", "row")) {
  # Previously an unrecognised method silently returned an all-zero matrix.
  method <- match.arg(method)
  if (method == "overall") {
    R <- diag(0, nrow(S))
    upper <- upper.tri(S)
    R[upper] <- rank(S[upper])
    return(R + t(R))
  }
  # -Inf is strictly below every finite similarity, so the diagonal always takes
  # rank 1 (0 after the shift). A finite sentinel such as min(S) - 100 can tie
  # with min(S) when the values are large.
  diag(S) <- -Inf
  t(apply(S, 1, rank)) - 1
}

#' Compute k-rNNG graph
#'
#' Builds the penalized K-nearest-neighbour graph with rank. Starting from the
#' directed K-NN graph of the distance matrix \code{M}, nodes are visited in a
#' random order. Each node's K out-neighbours are re-chosen by
#' \code{optimalwithrank_curnode}, which trades the neighbours' distance ranks
#' against a penalty on node degrees. The sweep stops once N consecutive visits
#' leave the graph unchanged. The visiting order is drawn with \code{sample()},
#' so the result depends on the RNG state.
#'
#' @param M A numeric N x N distance matrix (N >= 2)
#' @param K An integer number of out-neighbours per node, between 1 and N - 1
#' @param lambda A numeric weight of the degree penalty
#' @return A list containing the truncated KNN graph (\code{trun_KNN}, a
#'   two-column matrix of directed edges with columns \code{V1} = from,
#'   \code{V2} = to) and \code{degree}, the total (in + out) degree of each node
#' @references Zhu, Y., & Chen, H. (2023). A new robust graph for graph-based methods. \emph{arXiv preprint arXiv:2307.15205.}
#' @useDynLib MATES, .registration = TRUE
#' @importFrom Rcpp sourceCpp
P_Knear_rank <- function(M, K = round(nrow(M)^0.8), lambda = 0.3) {
  N <- nrow(M)
  if (K < 1 || K > N - 1) {
    stop("K must be between 1 and nrow(M) - 1.", call. = FALSE)
  }
  # Exclude self-distances: order() and rank() place NA last.
  diag(M) <- NA
  # Column i lists the other nodes by increasing distance from node i. The last
  # row (node i itself, NA) is dropped. drop = FALSE keeps a matrix when N = 2;
  # without it the later column indexing fails.
  Morder <- apply(M, 1, order)
  Morder <- Morder[-N, , drop = FALSE]
  KNN <- cbind(rep(seq_len(N), each = K),
               c(Morder[seq_len(K), , drop = FALSE]))

  # Column i holds the ranks of all distances seen from node i.
  rank_matrix <- apply(M, 1, rank)

  out_nodes <- Out_direct(KNN, seq_len(N))
  degree <- degree_distribution(KNN, seq_len(N))

  # Count of consecutive visits without a change; N in a row means converged.
  unchanged <- 1
  nodes <- sample(seq_len(N), N)
  id <- 1
  while (unchanged <= N) {
    if (id > N) {
      # Start a new sweep in a fresh random order.
      id <- 1
      nodes <- sample(seq_len(N), N)
    }
    node <- nodes[id]
    cur_neis <- out_nodes[[node]]
    op <- optimalwithrank_curnode(K, cur_neis, Morder[, node], degree, lambda,
                                  rank_matrix[, node])
    if (setequal(cur_neis, op$new_neis)) {
      unchanged <- unchanged + 1
    } else {
      degree <- op$degree
      out_nodes[[node]] <- op$new_neis
      unchanged <- 1
    }
    id <- id + 1
  }

  # Flatten the adjacency list into an edge list without growing a vector.
  trun_KNN <- cbind(rep(seq_len(N), lengths(out_nodes)), unlist(out_nodes))
  colnames(trun_KNN) <- c("V1", "V2")

  list(trun_KNN = trun_KNN, degree = degree)
}

#' One-step neighbour update for the penalized K-NN graph
#'
#' Helper used by \code{P_Knear_rank}. It first removes the node's current
#' out-edges from the degree tally. For each candidate j it then scores
#' \code{rowrank[j] + lambda * sum((degree + e_j - 2k)^2)}, where \code{e_j} is
#' the indicator of node j. The k candidates with the smallest score are kept;
#' \code{order()} breaks ties by position in \code{neighbor}.
#'
#' @param k Integer; desired number of neighbors (out-degree) for the current node
#' @param cur_neis Integer vector of current neighbors of the node
#' @param neighbor Integer vector of candidate neighbor indices for this node
#' @param degree Numeric vector of current degrees for all nodes
#' @param lambda A numeric representing the penalty parameter
#' @param rowrank Numeric vector of rank-based penalties for this node, indexed
#'   by node
#' @return A list with two elements:
#' \describe{
#'   \item{new_neis}{Integer vector of length \code{k} giving the updated
#'     neighbors of the node.}
#'   \item{degree}{Updated numeric degree vector for all nodes.}
#' }
#' @references Zhu, Y., & Chen, H. (2023). A new robust graph for graph-based methods. \emph{arXiv preprint arXiv:2307.15205.}
#' @useDynLib MATES, .registration = TRUE
#' @importFrom Rcpp sourceCpp
optimalwithrank_curnode <- function(k, cur_neis, neighbor, degree, lambda, rowrank) {
  # Detach the node from its current neighbours.
  degree[cur_neis] <- degree[cur_neis] - 1
  # Penalty of the detached configuration. Adding one edge to node j changes
  # only entry j: (d_j + 1 - 2k)^2 - (d_j - 2k)^2 = 2 * (d_j - 2k) + 1.
  # This gives the same value as rebuilding and summing a length-N vector per
  # candidate (exact on integer degrees), in O(N) instead of O(N * l).
  base_penalty <- sum((degree - 2 * k)^2)
  loss <- rowrank[neighbor] +
    lambda * (base_penalty + 2 * (degree[neighbor] - 2 * k) + 1)
  new_neis <- neighbor[order(loss)[seq_len(k)]]
  degree[new_neis] <- degree[new_neis] + 1
  list(new_neis = new_neis, degree = degree)
}

#' Auxiliary function to compute node degrees
#'
#' This function is used in 'P_Knear_rank' to compute the degrees. Every edge
#' adds one to each of its two endpoints, so the result is in-degree plus
#' out-degree (a self-loop counts twice).
#' @param G Integer or numeric matrix with two columns, where each row
#'        represents a directed edge \code{(from, to)} in the k-NN graph
#' @param sampleIDs Integer vector of node indices for which to compute degrees
#' @return Numeric vector of degrees with the same length and order as
#'         \code{sampleIDs} (empty when \code{sampleIDs} is empty)
degree_distribution <- function(G, sampleIDs) {
  endpoints <- c(G[, 1], G[, 2])
  # vapply over sampleIDs (not 1:length) so empty input yields numeric(0)
  # instead of NA.
  vapply(sampleIDs, function(id) sum(endpoints == id), numeric(1),
         USE.NAMES = FALSE)
}

#' Auxiliary function to build an adjacency list
#'
#' Collects the outgoing neighbours of each requested node from an edge list.
#' The result is indexed by node value, so requesting nodes that are not
#' contiguous from 1 leaves \code{NULL} gaps.
#'
#' @param K Integer or numeric matrix with two columns, where each row
#'        represents a directed edge \code{(from, to)} in the k-NN graph
#' @param nodes Integer vector of node indices for which to extract outgoing neighbors
#' @return A list where entry \code{out[[i]]} is the vector of neighbors \code{j}
#'         such that there is an edge \code{(i, j)} in \code{K}
Out_direct <- function(K, nodes) {
  tails <- K[, 1]
  successors <- list()
  for (node in nodes) {
    successors[[node]] <- K[which(tails == node), 2]
  }
  successors
}

#' Compute rank matrix
#'
#' This function computes the (symmetric) edge-weight matrix of a similarity
#' graph for the given graph type and number of neighbors:
#' \describe{
#'   \item{\code{"NNG"}}{each point's k most similar points get weights
#'     k, k - 1, ..., 1 (from row-wise ranks of \code{S}); all other entries
#'     are 0. The matrix is then added to its transpose.}
#'   \item{\code{"rNNG"}}{edges of the robust graph from \code{P_Knear_rank}
#'     (with \code{lambda = 0.3}). Within each row, entries are ranked by
#'     negated distance with non-edges set to 0 first; ranks above k are
#'     zeroed and the matrix is added to its transpose.}
#'   \item{\code{"MST"}}{edge counts accumulated over
#'     \code{ade4::mstree(Dd, kk)} for kk = 1, ..., k.}
#' }
#'
#' @param S A numeric matrix representing the similarity matrix
#' @param Dd A dist object representing the distance matrix
#' @param gtype A character string indicating the graph type to use
#'            ("NNG", "MST", or "rNNG"); any other value is an error
#' @param k A numeric representing the number of neighbors to use for graph
#' @importFrom ade4 mstree
#' @return A numeric matrix representing the rank matrix
#' @references Zhu, Y., & Chen, H. (2023). A new robust graph for graph-based methods. \emph{arXiv preprint arXiv:2307.15205.}
#' @useDynLib MATES, .registration = TRUE
#' @importFrom Rcpp sourceCpp
rank_mats <- function(S, Dd, gtype, k) {
  N <- nrow(S)
  if (gtype == "NNG") {
    # Row ranks: 0 on the diagonal, N - 1 for each row's most similar point.
    R.row <- Rise_Rank(S, method = "row")
    # Shift so the k most similar points get k, ..., 1; clip the rest to 0.
    R.nng <- R.row - N + 1 + k
    R.nng[R.nng < 0] <- 0
    R.nng + t(R.nng)
  } else if (gtype == "rNNG") {
    D.mat <- as.matrix(Dd)
    PKNNG <- P_Knear_rank(D.mat, K = k, lambda = 0.3)$trun_KNN
    # Symmetric 0/1 adjacency of the robust graph.
    R.rNNG.u <- matrix(0, N, N)
    R.rNNG.u[cbind(PKNNG[, 1], PKNNG[, 2])] <- 1
    R.rNNG.u[cbind(PKNNG[, 2], PKNNG[, 1])] <- 1
    R.rNNG.dist <- R.rNNG.u * D.mat
    R.rNNG.w1 <- t(apply(-R.rNNG.dist, 1, rank))
    R.rNNG.w1[R.rNNG.w1 > k] <- 0
    R.rNNG.w1 + t(R.rNNG.w1)
  } else if (gtype == "MST") {
    R.mst <- matrix(0, N, N)
    for (kk in seq_len(k)) {
      gpmst <- mstree(Dd, kk) # minimum spanning tree (number of edges = (N-1)*k)
      R.mst[cbind(gpmst[, 1], gpmst[, 2])] <- R.mst[cbind(gpmst[, 1], gpmst[, 2])] + 1
      R.mst[cbind(gpmst[, 2], gpmst[, 1])] <- R.mst[cbind(gpmst[, 2], gpmst[, 1])] + 1
    }
    R.mst
  } else {
    # Previously fell through and returned NULL, which callers turned into an
    # all-zero weight matrix without any error.
    stop("gtype must be 'NNG', 'MST', or 'rNNG'.", call. = FALSE)
  }
}


#' MATES test statistic with pre-computed view matrices
#'
#' This function takes a list of view matrices (R_list) and the observed
#' within-sample statistics to compute the MATES test statistic. The statistic
#' is the quadratic form of the centred \code{UxUy} vector in the inverse of
#' its covariance. Mean and covariance come from the closed form
#' (\code{asy_mean}/\code{asy_cov}) or from random permutations of the sample
#' labels. If the covariance cannot be inverted with \code{solve()}, a warning
#' is issued and the Moore-Penrose pseudo-inverse (\code{MASS::ginv}) is used.
#'
#' @param UxUy A numeric vector of length 2*S containing the Ux and Uy statistics
#'               for each view, ordered (Ux_1, Uy_1, Ux_2, Uy_2, ...)
#' @param R_list A list of numeric matrices with length S
#' @param m An integer representing the number of samples in X
#' @param n An integer representing the number of samples in Y
#' @param perm An integer (at least 2) giving the number of permutations
#'            (default is NULL, which uses the closed form)
#' @importFrom stats pchisq
#' @importFrom stats cov
#' @importFrom magrittr %>%
#' @return A list with the MATES test statistic (test.stat) and p-value (pval).
#'   The closed-form p-value is the upper tail of a chi-squared distribution
#'   with 2*S degrees of freedom. The permutation p-value is the share of
#'   permuted statistics strictly larger than the observed one.
#' @export
#' @examples
#' \donttest{
#' # Generate simulated data
#' set.seed(123)
#' X <- matrix(rnorm(20), ncol = 2)  # 10 samples, 2 dimensions
#' Y <- matrix(rnorm(20), ncol = 2)  # 10 samples, 2 dimensions
#' Z <- rbind(X, Y)
#' m <- nrow(X)
#' n <- nrow(Y)
#' N <- m + n
#'
#' # Compute distance and similarity matrices
#' D <- as.matrix(dist(Z, method = "manhattan"))
#' S <- max(D) - D
#'
#' # Compute rank matrix (simplified NNG approach)
#' R <- matrix(0, N, N)
#' k <- 3
#' for(i in 1:N) {
#'   neighbors <- order(D[i,])[2:(k+1)]  # k nearest neighbors
#'   R[i, neighbors] <- 1:k
#' }
#' R <- R + t(R)
#'
#' # Create list with one rank matrix
#' R_list <- list(R)
#'
#' # Calculate test statistics (Ux and Uy)
#' sample1ID <- 1:m
#' sample2ID <- (m+1):N
#' Ux <- sum(R[sample1ID, sample1ID])
#' Uy <- sum(R[sample2ID, sample2ID])
#' UxUy <- c(Ux, Uy)
#'
#' # Perform MATES test
#' result <- MATES_test(UxUy, R_list, m = m, n = n)
#' print(result$test.stat)
#' print(result$pval)
#' }
#' @useDynLib MATES, .registration = TRUE
#' @importFrom Rcpp sourceCpp
MATES_test <- function(UxUy, R_list, m, n, perm = NULL) {
  S <- length(R_list)
  N <- m + n
  if (S < 1) {
    stop("R_list must contain at least one matrix.", call. = FALSE)
  }
  if (length(UxUy) != 2 * S) {
    stop("UxUy must have length 2 * length(R_list).", call. = FALSE)
  }
  # Contrast applied to the centred statistics; the identity keeps all 2*S
  # components.
  MATES_mat <- diag(1, nrow = 2 * S, ncol = 2 * S)

  # Inverse of MATES_mat %*% Sigma %*% t(MATES_mat), falling back to the
  # pseudo-inverse when solve() fails. The result is returned from tryCatch.
  # The previous `<<-` in the handler assigned into the global environment,
  # because the variable did not yet exist locally.
  quad_inverse <- function(Sigma) {
    A <- MATES_mat %*% Sigma %*% t(MATES_mat)
    tryCatch(
      solve(A),
      error = function(e) {
        warning(conditionMessage(e),
                " -- leads to singular matrix! Use ginv instead of solve.",
                call. = FALSE)
        MASS::ginv(A)
      }
    )
  }

  # Quadratic form t(M x) %*% A_inv %*% (M x) as a plain number. The observed
  # and permuted statistics use this same function so that ties compare
  # exactly.
  quad_form <- function(x, A_inv) {
    z <- MATES_mat %*% x
    drop(t(z) %*% A_inv %*% z)
  }

  if (is.null(perm)) {
    mUxUy.cf <- asy_mean(R_list, m, n)
    covUxUy.cf <- asy_cov(R_list, m, n)
    mat_inv.cf <- quad_inverse(covUxUy.cf)
    test.stat <- quad_form(UxUy - mUxUy.cf, mat_inv.cf)
    # Upper tail directly; 1 - pchisq() rounds small p-values to 0.
    pval <- pchisq(test.stat, df = nrow(MATES_mat), lower.tail = FALSE)
    return(list(test.stat = test.stat, pval = pval))
  }

  if (!is.numeric(perm) || length(perm) != 1 || perm < 2) {
    stop("perm must be NULL or a single number >= 2.", call. = FALSE)
  }
  # Within-sample weight sums interleaved as (Ux_1, Uy_1, Ux_2, Uy_2, ...).
  within_sums <- function(ID1, ID2) {
    ux <- vapply(R_list, function(R) sum(R[ID1, ID1]), numeric(1))
    uy <- vapply(R_list, function(R) sum(R[ID2, ID2]), numeric(1))
    as.vector(rbind(ux, uy))
  }
  # 2S x perm matrix: one column per random relabelling of the pooled sample.
  BR <- vapply(seq_len(perm), function(iperm) {
    ID1 <- sample(seq_len(N), size = m, replace = FALSE)
    ID2 <- setdiff(seq_len(N), ID1)
    within_sums(ID1, ID2)
  }, numeric(2 * S))

  perm_UxUy <- t(BR)
  mUxUy.perm <- colMeans(perm_UxUy)
  covUxUy.perm <- cov(perm_UxUy)
  BR0 <- scale(perm_UxUy, center = mUxUy.perm, scale = FALSE)
  mat_inv.perm <- quad_inverse(covUxUy.perm)
  test.stat <- quad_form(UxUy - mUxUy.perm, mat_inv.perm)
  test.stat.perm <- apply(BR0, 1, quad_form, A_inv = mat_inv.perm)
  pval <- mean(test.stat < test.stat.perm)
  list(test.stat = test.stat, pval = pval)
}

#' MATES test statistic with two samples (recommended for general use)
#'
#' This function takes two data matrices (m x d and n x d) and other parameters
#' to compute the MATES test statistic. It only implements the same distance,
#' graph, and weight options across all views. For other combinations, please
#' compute the corresponding view matrices (R_list) and use the MATES_test
#' function directly.
#'
#' @param X A numeric matrix of size m x d
#' @param Y A numeric matrix of size n x d
#' @param S An integer (at least 1) giving the number of views. View s uses the
#'          manhattan distance of the pooled data raised elementwise to the
#'          power s (\code{dt = "manhattan"}), or the Minkowski distance with
#'          \code{p = s} (\code{dt = "Lp"})
#' @param dt A character string indicating the distance metric to use
#'                ("manhattan" or "Lp")
#' @param gh A character string indicating the graph type to use
#'             ("NNG", "MST", or "rNNG")
#' @param wt A character string indicating the weight function to use
#'               ("kernel", "rank", "distance", or "plain"). "distance" uses
#'               the similarity max(D) - D on graph edges; "kernel" uses
#'               exp(-D^2 / sigma^2) with sigma the median pairwise distance
#' @param pow A numeric between 0 and 1 controlling the number of neighbors
#'            used for the graph. If pow = 0, min(N - 1, 10) neighbors are
#'            used; otherwise min(N - 1, round(N^pow))
#' @param perm An integer indicating the number of permutations (default is
#'             NULL, which uses the closed form)
#' @importFrom stats as.dist
#' @importFrom stats median
#' @return A list with the MATES test statistic (test.stat) and p-value (pval)
#' @examples
#' # Generate two-sample data from different distributions
#' set.seed(123)
#' X <- matrix(rnorm(50, mean = 0), ncol = 5)  # 10 samples from N(0,1)
#' Y <- matrix(rnorm(50, mean = 0.5), ncol = 5)  # 10 samples from N(0.5,1)
#'
#' # Perform MATES test
#' result <- MATES(X, Y, S = 4, dt = "manhattan", gh = "NNG", wt = "kernel", pow = 0.8)
#' print(result$test.stat)
#' print(result$pval)
#' @export
#' @useDynLib MATES, .registration = TRUE
#' @importFrom Rcpp sourceCpp
MATES <- function(X, Y, S = 4, dt = "manhattan", gh = "NNG", wt = "kernel", pow = 0.8, perm = NULL) {
  # stop if inputs are invalid
  if (!is.matrix(X) || !is.matrix(Y)) {
    stop("X and Y must be matrices.")
  }
  if (ncol(X) != ncol(Y)) {
    stop("X and Y must have the same number of columns.")
  }
  if (!(dt %in% c("manhattan", "Lp"))) {
    stop("distance must be 'manhattan' or 'Lp'.")
  }
  # "rNNGw" used to be accepted here although rank_mats has no such graph,
  # which produced all-zero weight matrices.
  if (!(gh %in% c("NNG", "MST", "rNNG"))) {
    stop("graph must be 'NNG', 'MST', or 'rNNG'.")
  }
  if (!(wt %in% c("kernel", "rank", "distance", "plain"))) {
    stop("weight must be 'kernel', 'rank', 'distance', or 'plain'.")
  }
  if (pow < 0 || pow > 1) {
    stop("pow must be between 0 and 1.")
  }
  if (!is.numeric(S) || length(S) != 1 || S < 1) {
    stop("S must be a positive integer.")
  }
  if (nrow(X) < 1 || nrow(Y) < 1) {
    stop("X and Y must each have at least one row.")
  }

  m <- nrow(X)
  n <- nrow(Y)
  N <- m + n
  sample1ID <- seq_len(m)
  sample2ID <- (m + 1):N
  Z <- rbind(X, Y)
  row.names(Z) <- 1:N
  # A graph cannot have more than N - 1 neighbours per point; the pow = 0
  # default of 10 is capped too (otherwise NNG gets self-loops and rNNG fails).
  KK <- if (pow == 0) min(N - 1, 10) else min(N - 1, round(N^pow))

  R_list <- vector("list", S)
  UxUy <- numeric(2 * S)
  for (s in seq_len(S)) {
    D <- if (dt == "manhattan") manhattan_dist(Z^s) else minkowski_dist(Z, p = s)
    sim <- max(D) - D
    Dd <- as.dist(D)
    sig <- median(D[upper.tri(D)])
    graph_rank <- rank_mats(sim, Dd, gh, KK)
    on_graph <- graph_rank > 0
    R <- matrix(0, N, N)
    R[on_graph] <- switch(wt,
      distance = sim[on_graph],
      rank = graph_rank[on_graph],
      kernel = exp(-D[on_graph]^2 / sig^2),
      plain = 1
    )
    R_list[[s]] <- R
    UxUy[2 * s - 1] <- sum(R[sample1ID, sample1ID])
    UxUy[2 * s] <- sum(R[sample2ID, sample2ID])
  }
  MATES_test(UxUy, R_list, m, n, perm)
}















