#----------------------------#
# UTILITY / HELPER FUNCTIONS #
#----------------------------#


#' Convert character variables in formula to factor variables
#'
#' As it's an internal helper function to aid in testing, it is not exported
#' for use outside of the package.
#'
#' The first element of \code{vars} is treated as the response (treatment)
#' variable and is never converted. Every other column of \code{df} named in
#' \code{vars} that is of type character is converted with \code{as.factor()};
#' all remaining columns are returned unchanged.
#'
#' @param df A dataframe.
#' @param vars A vector of dataframe column names, e.g.
#'   \code{all.vars(formula)}. The first element is skipped.
#' @return Dataframe where the character variables named in \code{vars[-1]}
#'   are now factors.
#'
#' @examples
#' \dontrun{
#' df <- data.frame(x = 1:3, y = 3:1, z = c("a", "b", "c"),
#'                  stringsAsFactors = FALSE)
#' vars <- names(df)
#' df <- chr_2_factor(df, vars)
#' lapply(df, class)
#' }
#'
#' @keywords internal
chr_2_factor <- function(df, vars){
  if (!is.data.frame(df)) {
    stop("Input data must a dataframe")
  }
  if (!is.vector(vars)) {
    stop("Input variable list must be a vector")
  }
  # vars[1] is the response; only the remaining variables are candidates.
  predictors <- vars[-1]
  # vapply (not sapply) so an empty predictor set yields logical(0) rather
  # than list(), which is not a valid subscript.
  is_chr <- vapply(df[predictors], is.character, logical(1))
  chr_cols <- predictors[is_chr]
  if (length(chr_cols) > 0) {
    df[chr_cols] <- lapply(df[chr_cols], as.factor)
  }
  df
}


#' Run a logistic or probit model
#'
#' As it's an internal helper function to aid in testing, it is not exported
#' for use outside of the package.
#'
#' Fits \code{fm} with \code{glm()} using a binomial family with a logit or
#' probit link, then attaches one score per row of \code{reduced_data}:
#' either the fitted probability or its logit.
#'
#' @param model_type Use logistic regression ("logistic") or "probit"
#' regression ("probit") to estimate the predicted probability of participating
#' @param match_on Match on estimated propensity score ("pscore") or logit of
#' estimated propensity score ("logit").
#' @param reduced_data Dataframe of reduced treatment and comparison data
#' @param id ID variable in dataset
#' @param treat Treatment variable in dataset
#' @param entry Entry quarter variable in dataset
#' @param tm Time period variable in dataset
#' @param fm Formula for the propensity model, e.g. \code{treat ~ x1 + x2}.
#'
#' @examples
#' \dontrun{
#' data(package="rollmatch", "rem_synthdata_small")
#' formula <- as.formula(treat ~ qtr_pmt + age + is_male + is_white +
#'                        is_disabled + is_esrd + months_dual + chron_num + lq_ed +
#'                        yr_ed2 + lq_ip + yr_ip2)
#' vars <- all.vars(formula); treat <- vars[1]
#' tm <- "quarter"; entry <- "entry_q"; id <- "indiv_id"
#' model_type <- "logistic"; match_on <- "logit"
#' reduced_data <- chr_2_factor(rem_synthdata_small, vars)
#' model_output <- runModel(model_type, match_on, reduced_data, id, treat, entry,
#'               tm, formula)
#' head(model_output)
#' }
#'
#' @return list of the model ($pred_model) and the resulting dataframe with
#' columns \code{id}, \code{treat}, \code{entry}, \code{tm} and the predicted
#' \code{score} ($lr_result). Stops with an error if \code{model_type} or
#' \code{match_on} is invalid, or if the model did not produce a prediction
#' for every row (e.g. rows dropped because of NAs).
#' @keywords internal
runModel <- function(model_type, match_on, reduced_data, id, treat, entry,
                     tm, fm){
  # Placeholders for R CMD check. A name in call position is looked up as a
  # function, skipping these character bindings, so the stats functions are
  # still the ones called.
  glm <- "."; binomial <- "."; qlogis <- "."

  # Validate both options before paying for the model fit.
  if (!(model_type %in% c("logistic", "probit"))) {
    stop("model_type must be set to either logistic or probit")
  }
  if (!(match_on %in% c("logit", "pscore"))) {
    stop("match_on must be set to either logit or pscore.")
  }

  link_type <- if (model_type == "logistic") "logit" else "probit"
  pred_model <- glm(fm, data = reduced_data,
                    family = binomial(link = link_type))

  lr_result <- reduced_data[, c(id, treat, entry, tm)]
  fitted_vals <- pred_model$fitted.values
  # glm drops incomplete rows, which would misalign scores with lr_result.
  if (nrow(lr_result) != length(fitted_vals)) {
    na_cols <- colnames(reduced_data)[colSums(is.na(reduced_data)) > 0]
    stop(paste0("Propensity model could not create prediction for all
                observations. Check your data for issues. If any columns had
                NAs, they are printed here: ",
                paste(na_cols, collapse = ", ")))
  }

  lr_result$score <-
    if (match_on == "logit") qlogis(fitted_vals) else fitted_vals

  #Todo: subclass - see glm and matchit
  list(pred_model = pred_model, lr_result = lr_result)
}


#' Create a dataframe of comparisons between all treatment and control data
#'
#' As it's an internal helper function to aid in testing, it is not exported
#' for use outside of the package.
#'
#' Every treated row of \code{lr_result} is paired with every comparison row
#' that has the same value of \code{tm}, and the absolute difference of their
#' \code{score} values is recorded.
#'
#' @param lr_result The dataset given to runModel with the addition of the
#' score values generated by the propensity score model
#' @param tm The time period indicator
#' @param entry Entry quarter for subject
#' @param id ID variable in dataset
#' @param treat Name of the treatment indicator column in \code{lr_result}
#' (1 = treated, 0 = comparison). Defaults to \code{"treat"}, the column name
#' this function previously hard-coded.
#'
#' @examples
#' \dontrun{
#' load(url(paste0("https://github.com/RTIInternational/rollmatch/raw/master/",
#'                  "tests/testthat/lr_result.rda")))
#' tm <- "quarter"; entry <- "entry_q"; id <- "indiv_id"
#' comparison_pool <- createComparison(lr_result, tm, entry, id)
#' head(comparison_pool)
#' }
#'
#' @return Dataframe with one row per treated/comparison pair sharing a time
#' period, with columns \code{time}, \code{treat_id}, \code{control_id},
#' \code{treat_<entry>}, \code{control_<entry>} and \code{difference}.
#' @keywords internal
createComparison <- function(lr_result, tm, entry, id, treat = "treat"){
  if (!(treat %in% names(lr_result))) {
    stop("lr_result has no treatment column named '", treat, "'")
  }
  is_treated <- lr_result[[treat]] == 1
  is_control <- lr_result[[treat]] == 0

  # Pair every treated row with every comparison row in the same period;
  # shared column names get the ".x" (treated) / ".y" (comparison) suffix.
  comparison_pool <-
    dplyr::inner_join(lr_result[is_treated, ],
                      lr_result[is_control, ], by = c(tm))

  comparison_pool$difference <- abs(comparison_pool$score.x -
                                      comparison_pool$score.y)

  keep_cols <- c(tm, paste0(id, c(".x", ".y")),
                 paste0(entry, c(".x", ".y")), "difference")
  comparison_pool <- comparison_pool[, keep_cols]
  names(comparison_pool) <- c("time", "treat_id", "control_id",
                              paste0("treat_", entry),
                              paste0("control_", entry), "difference")
  comparison_pool
}

#' Use a caliper to trim the comparison data to only observations within threshold
#'
#' As it's an internal helper function to aid in testing, it is not exported
#' for use outside of the package.
#'
#' @param alpha The pre-specified distance within which to allow matching.
#' The caliper width is calculated as the \code{alpha} multiplied by the
#' pooled standard deviation of the propensity scores or the logit of the
#' propensity scores - depending on the value of \code{match_on}. An
#' \code{alpha} of 0 disables trimming.
#' @param data_pool Dataframe of comparison data to be trimmed.
#' @param lr_result Dataframe of results from model in runModel.
#' @param weighted_pooled_stdev Boolean. FALSE for average pooled standard
#' deviation and TRUE for weighted pooled standard deviation.
#' @param treat Name of the treatment indicator column in \code{lr_result}
#' (1 = treated, 0 = comparison). Defaults to \code{"treat"}, the column name
#' this function previously hard-coded.
#'
#' @examples
#' \dontrun{
#' load(url(paste0("https://github.com/RTIInternational/rollmatch/raw/master/",
#'                  "tests/testthat/lr_result.rda")))
#' load(url(paste0("https://github.com/RTIInternational/rollmatch/raw/master/",
#'                 "tests/testthat/comparison_pool.rda")))
#' trimmed_pool <- trimPool(alpha = .2, data_pool = comparison_pool,
#'                          lr_result = lr_result)
#' head(trimmed_pool)
#' }
#'
#' @return Dataframe of the trimmed comparisons based on the alpha value,
#' sorted by \code{time}, \code{treat_id} and \code{difference}.
#' @keywords internal
trimPool <- function(alpha, data_pool, lr_result,
                     weighted_pooled_stdev = FALSE, treat = "treat"){
  var <- "."  # R CMD check placeholder; var() below still calls stats::var

  if (nrow(data_pool) == 0) {
    stop("data_pool is empty")
  }

  if (alpha != 0) {
    if (!(treat %in% names(lr_result))) {
      stop("lr_result has no treatment column named '", treat, "'")
    }
    is_treated <- lr_result[[treat]] == 1
    is_control <- lr_result[[treat]] == 0
    var_treat <- var(lr_result[is_treated, "score"])
    var_untreat <- var(lr_result[is_control, "score"])

    if (weighted_pooled_stdev == FALSE) {
      pooled_stdev <- sqrt((var_treat + var_untreat) / 2)
    } else {
      n_treat <- nrow(lr_result[is_treated, ])
      n_untreat <- nrow(lr_result[is_control, ])
      pooled_stdev <- sqrt(((n_treat - 1) * var_treat +
                              (n_untreat - 1) * var_untreat) /
                             (nrow(lr_result) - 2))
    }

    # as.numeric(): var() of a one-column tibble returns a 1x1 matrix.
    width <- alpha * pooled_stdev
    trimmed_data <- dplyr::filter(data_pool,
                                  data_pool$difference <= as.numeric(width))
  } else {
    trimmed_data <- data_pool
  }

  trimmed_data[order(trimmed_data$time, trimmed_data$treat_id,
                     trimmed_data$difference), ]
}


#' Algorithm to find best matches from the comparison pool
#'
#' As it's an internal helper function to aid in testing, it is not exported
#' for use outside of the package.
#'
#' Treatment beneficiaries with exactly one candidate in \code{trimmed_pool}
#' are assigned that candidate up front. The rest of the pool is matched in
#' rounds.
#'
#' In each round, every remaining treatment beneficiary takes the first row
#' left for it in the pool. The pool is re-sorted by time, treat_id and
#' difference after every round; the first round relies on the caller's order
#' (e.g. the order produced by \code{trimPool}). A comparison beneficiary that
#' is the first choice in more than one time period is assigned only in the
#' period with the smallest mean difference.
#'
#' Treatment beneficiaries that reach \code{num_matches} are removed from the
#' pool. Rounds stop when no first choices remain or when a round assigns no
#' match.
#'
#' @param trimmed_pool Dataframe containing the pool from which matches
#' should be found (e.g. the output of \code{trimPool}); uses the columns
#' \code{time}, \code{treat_id}, \code{control_id} and \code{difference}.
#' @param num_matches Integer. The number of comparison beneficiary matches to
#' attempt to assign to each treatment beneficiary
#' @param replacement Boolean. Assign comparison beneficiaries with replacement
#' (TRUE) or without replacement (FALSE). If \code{replacement} is TRUE, then
#' comparison beneficiaries will be allowed to be used with replacement within
#' a single quarter, but will not be allowed to match to different treatment
#' beneficiaries across multiple quarters. If FALSE, a comparison beneficiary
#' matched in a matching round is removed from the pool.
#' 
#' @examples
#' \dontrun{ 
#' num_matches <- 3; replacement <- TRUE
#' load(url(paste0("https://github.com/RTIInternational/rollmatch/raw/master/",
#'                 "tests/testthat/trimmed_pool.rda")))
#' matches <- createMatches(trimmed_pool, num_matches, replacement)
#' head(matches)
#' }
#'
#' @return Dataframe of the pool rows selected as matches, one row per
#' assigned treatment/comparison pair.
#' @keywords internal
createMatches <- function(trimmed_pool, num_matches = 3, replacement=TRUE){
  # initialize matches as empty
  matches <- trimmed_pool[0, ]

  # Prematch. Assign controls to treatments that have just one possible match.
  # The "." bindings are placeholders for names used via non-standard
  # evaluation or called as functions (presumably to quiet R CMD check);
  # function lookup skips them, so aggregate() still calls stats::aggregate.
  treat_id <- "."; num.controls <- "."; control_id <- "."; time <- "."
  difference <- "."; num.assigned <- "."; reshape <- "."; aggregate <- "."
  just_one <- trimmed_pool %>%
    dplyr::group_by(treat_id) %>%
    dplyr::summarise(num.controls = n()) %>%
    dplyr::filter(num.controls == 1)

  matches <- dplyr::bind_rows(matches, trimmed_pool %>%
                                dplyr::filter(treat_id %in% just_one$treat_id))

  #Remove matched treatments from comparison pool
  # NOTE(review): only the treat_ids are removed here; the controls assigned
  # in this prematch stay in the pool and can still be picked for other
  # treatments below, even when replacement = FALSE. Confirm this is intended.
  trimmed_pool <-
    dplyr::filter(trimmed_pool, !(treat_id %in% just_one$treat_id ))

  # count is incremented each round but not otherwise read in this function.
  count <- 1
  # Loop: one matching round per iteration
  repeat {
    # first_choice is the first entry in the comparison pool for each treat_id
    first_choice <- trimmed_pool[!duplicated(trimmed_pool$treat_id), ]

    if (nrow(first_choice) == 0)
      break

    # Deal with matches that match in more than one quarter: controls whose
    # first-choice rows span more than one distinct time value.
    multi_quarter <- aggregate(time ~ control_id, first_choice,
                               function(x) length(unique(x)))
    multi_quarter <- multi_quarter[multi_quarter$time > 1, ]

    # Initialize empty multicompare data frame
    cnames <- c("time", "treat_id", "control_id", "difference")
    matched_multi_compare <-
      data.frame(matrix(vector(), 0, 4, dimnames = list(c(), cnames)))
    # Todo: How to initialize the empty df for matches/matched_multi_compare

    if (nrow(multi_quarter) != 0) {

      # Mean difference per (time, control) for the multi-quarter controls;
      # each such control is kept only in its lowest-mean time period.
      multi_compare <-
        aggregate(difference ~ time + control_id,
                  first_choice[first_choice$control_id %in%
                                 multi_quarter$control_id, ], FUN = mean)

      multi_compare <- multi_compare[order(multi_compare$control_id,
                                           multi_compare$difference), ]
      multi_compare_assigned <-
        multi_compare[!duplicated(multi_compare$control_id), ]

      # NOTE(review): this merge also keys on `difference`, which here is a
      # mean. When several treatments share the control in that period, the
      # mean may equal none of their differences and no row merges back.
      # Confirm this is intended.
      matched_multi_compare <- # Todo - Change to DPLYR
        merge(multi_compare_assigned, first_choice,
              by = c("control_id", "time", "difference"))
    }

    # Deal with matches in single quarter - these can be assigned directly
    # NOTE(review): if several treatments in the same period share a
    # first-choice control, all of those pairs are assigned in this round,
    # regardless of `replacement`.
    matched_single_compare <- first_choice[!first_choice$control_id %in%
                                             multi_quarter$control_id, ]
    current_matches <- dplyr::bind_rows(matched_multi_compare,
                                        matched_single_compare)

    # Break out of loop if no matches were assigned
    if (nrow(current_matches) == 0)
      break

    matches <- dplyr::bind_rows(matches, current_matches)

    # trimmed_pool is non-empty whenever first_choice was, so the else
    # branch below appears unreachable.
    if (nrow(trimmed_pool) > 0) {
      #filter out assigned treatment/match pairs
      trimmed_pool <- dplyr::setdiff(trimmed_pool, current_matches)

      # Keep records whose control_id was not matched in this round
      # (i.e. not in current_matches$control_id)
      diff_control_id <-
        dplyr::filter(trimmed_pool,
                      !(control_id %in%
                          unique(current_matches$control_id) ))

      if (replacement){
        # if replacement TRUE, keep records where control_id is in
        # matches$control_id and time is the same as the matched time
        keep <-
          dplyr::inner_join(
            trimmed_pool,
            unique(current_matches[, c("time", "control_id")]),
            by = c("control_id", "time"))
      }else{
        keep <- NULL
      }
      #combine the rows to keep and re-sort comparison pool
      trimmed_pool <-
        dplyr::arrange(dplyr::bind_rows(diff_control_id, keep),
                       time, treat_id, difference)

      # If num_matches matches have been assigned, remove the treatments
      # from the comparison pool
      matches_count <- matches[matches$treat_id %in%
                                 unique(current_matches$treat_id), ] %>%
        dplyr::group_by(treat_id) %>%
        dplyr::summarise(num.assigned = n()) %>%
        dplyr::filter(num.assigned == num_matches)
      trimmed_pool <-
        dplyr::filter(trimmed_pool,
                      !(treat_id %in% matches_count$treat_id ))

    } else break  # break out of loop if comparison pool is empty
    count <- count + 1
  }
  return(matches)
}


#' Create additional columns for the matches dataset
#'
#' As it's an internal helper function to aid in testing, it is not exported
#' for use outside of the package.
#'
#' Adds, per row:
#' \itemize{
#'   \item \code{match_rank}: 1, 2, 3, ... in the order each treat_id's rows
#'     appear in \code{matches};
#'   \item \code{total_matches}: number of rows for that treat_id;
#'   \item \code{treatment_weight}: always 1;
#'   \item \code{control_matches}: number of rows in which that control_id
#'     appears;
#'   \item \code{row_weight}: \code{1 / total_matches}.
#' }
#' Rows are returned sorted by \code{treat_id} and \code{match_rank}. An empty
#' \code{matches} yields an empty dataframe with these columns added.
#'
#' @param matches Dataframe containing the matches from comparison_pool
#'
#' @examples
#' \dontrun{
#' load(url(paste0("https://github.com/RTIInternational/rollmatch/raw/master/",
#'                 "tests/testthat/matches.rda")))
#' newmatches <- addMatchesColumns(matches)
#' head(newmatches)
#' }
#'
#' @return Dataframe containing top matches
#' @keywords internal
addMatchesColumns <- function(matches){
  ave <- "."  # R CMD check placeholder; ave() below still calls stats::ave
  # seq_len (not 1:nrow) so zero rows gives integer(0) rather than c(1, 0).
  # Assign a number to the matches. 1st, 2nd 3rd, ...
  matches$match_rank <- ave(seq_len(nrow(matches)), matches$treat_id,
                            FUN = seq_along)
  matches <- matches[order(matches$treat_id, matches$match_rank), ]
  # Calculate number of total matches for a given treat_id
  matches$total_matches <- ave(seq_len(nrow(matches)), matches$treat_id,
                               FUN = length)
  # rep() keeps the assignment valid when there are zero rows.
  matches$treatment_weight <- rep(1, nrow(matches))
  # Calculate how many times each control is used
  matches$control_matches <- ave(seq_len(nrow(matches)), matches$control_id,
                                 FUN = length)
  matches$row_weight <- 1 / matches$total_matches

  matches
}


#' Create control weights for matches dataset, and final data output
#'
#' As it's an internal helper function to aid in testing, it is not exported
#' for use outside of the package.
#'
#' Each comparison beneficiary's weight is the sum of \code{row_weight} over
#' all of its matches. Treated ids get weight 1. Rows of \code{data} whose id
#' is neither matched nor treated get \code{NA}.
#'
#' @param matches Dataframe containing the matches from comparison_pool, with
#' the columns added by \code{addMatchesColumns}.
#' @param data The original data provided for the function.
#' @param id The individual id variable.
#'
#' @examples
#' \dontrun{
#' id <- "indiv_id"
#' data(package="rollmatch", "rem_synthdata")
#' load(url(paste0("https://github.com/RTIInternational/rollmatch/raw/master/",
#'                 "tests/testthat/matches.rda")))
#' data <- rem_synthdata
#' matches <- addMatchesColumns(matches)
#' out_list <- createWeights(matches, data, id)
#' head(out_list$data_full)
#' }
#'
#' @return A list containing two Dataframes. Matches - an updated dataset with
#' control weights added (merged by control_id), and data_full an updated
#' version of the original data with the wide-format control ids and a
#' \code{weight} column added.
#' @keywords internal
createWeights <- function(matches, data, id){
  # Placeholders for names used via non-standard evaluation or called as
  # functions (presumably to quiet R CMD check); function lookup skips these
  # character bindings, so reshape() is still stats::reshape.
  control_id <- "."; row_weight <- "."; reshape <- "."; aggregate <- ".";

  # Total weight carried by each comparison beneficiary.
  control_totals <- matches %>%
    dplyr::group_by(control_id) %>%
    dplyr::summarise(total.weight = sum(row_weight))
  names(control_totals) <- c("control_id", "control_weight")

  # One row per treat_id with control_id.<rank> columns. The code below
  # relies on treat_id being the first column kept by the grepl filter.
  wide <- reshape(matches, v.names = c("control_id", "difference"),
                  idvar = "treat_id", timevar = "match_rank",
                  direction = "wide",
                  drop = c("row_weight", "control_matches"))
  wide <- wide[, grepl("treat_id|control_id", names(wide))]
  names(wide)[1] <- id

  # Attach each control's total weight to its match rows.
  matches <- merge(matches, control_totals, by = "control_id")

  # Lookup table: controls carry their total weight, treated ids weight 1.
  weight_lookup <- control_totals
  names(weight_lookup) <- c("id_var", "control_weight")
  weight_lookup <- rbind(weight_lookup,
                         cbind(id_var = unique(matches$treat_id),
                               control_weight = 1))

  data_full <- dplyr::left_join(data, wide, by = id)
  data_full$weight <-
    with(weight_lookup, control_weight[match(data_full[[id]], id_var)])

  list(matches = matches, data_full = data_full)
}


#' Combine the results of rollmatch into a tidy list for output
#'
#' As it's an internal helper function to aid in testing, it is not exported
#' for use outside of the package.
#'
#' @param pred_model The propensity scoring model created in runModel
#' @param lr_result The dataset given to runModel with the addition of the
#' score values generated by the propensity score model. runModel builds it
#' with the columns id, treat, entry, tm (in that order) plus score.
#' @param data_full The original data provided with the addition of
#' control weights
#' @param matches Dataframe containing the matches from comparison_pool
#' @param orig.call The original call of the main function
#' @param formula Original formula used; its first variable is taken as the
#' treatment indicator.
#' @param tm The time period indicator.
#' @param entry The time period in which the participant enrolled in the
#' intervention (in the same units as the tm variable).
#' @param lookback The number of time periods to look back before the
#' time period of enrollment (1-10).
#' @param id Name of the id column in \code{lr_result}. Defaults to its first
#' column, which is where runModel places the id.
#'
#' @examples
#' \dontrun{
#' orig.call <- "Ignore"
#' formula <- as.formula(treat ~ qtr_pmt + yr_pmt + age + is_male + is_white +
#'                        is_disabled + is_esrd + months_dual + chron_num + lq_ed +
#'                        yr_ed2 + lq_ip + yr_ip2)
#' tm <- "quarter"; entry <- "entry_q"; lookback <- 1
#' load(url(paste0("https://github.com/RTIInternational/rollmatch/raw/master/",
#'                 "tests/testthat/output.rda")))
#' pred_model <- output$pred_model
#' lr_result <- output$lr_result
#' load(url(paste0("https://github.com/RTIInternational/rollmatch/raw/master/",
#'                 "tests/testthat/out_list.rda")))
#' data_full <- out_list$data_full
#' matches <- out_list$matches
#' out <- makeOutput(pred_model, lr_result, data_full, matches, orig.call,
#'                   formula, tm, entry, lookback)
#' head(out)
#' }
#'
#' @return \code{output} returns a list containing the following components:
#' \item{call}{The original \code{rollmatch} call.}
#' \item{model}{The output of the model used to estimate the distance measure.}
#' \item{scores}{The propensity score and logit of the propensity score. }
#' \item{data}{The original dataset with matches, scores, and weights applied.}
#' \item{summary}{A basic summary table.}
#' \item{ids_not_matched}{A vector of unmatched treatment ids}
#' \item{total_not_matched}{The count of unmatched treatment ids}
#' \item{matched_data}{The \code{matches} dataframe as passed in.}
#' @keywords internal
makeOutput <- function(pred_model, lr_result, data_full, matches, orig.call,
                       formula, tm, entry, lookback,
                       id = names(lr_result)[1]){
  qlogis <- "."  # R CMD check placeholder; qlogis() still calls stats::qlogis

  # The treatment indicator is the first variable of the formula.
  treat <- all.vars(formula)[1]

  out <- list()  #Todo: subclass - see glm and matchit
  out$call <- orig.call
  out$model <- pred_model
  out$scores <- lr_result[, c(1:4)]
  out$scores$pscore <- pred_model$fitted.values
  out$scores$logit <- qlogis(pred_model$fitted.values)
  out$data <- data_full

  # Treated rows observed `lookback` periods before entry, and comparison rows
  # in any of those periods. which() drops rows where the test is NA instead
  # of adding all-NA rows that would inflate the counts.
  treat_set <- data_full[which(data_full[[treat]] == 1 &
                                 (data_full[[tm]] ==
                                    data_full[[entry]] - lookback)), ]
  comp_set <- data_full[which(data_full[[treat]] == 0 &
                                (data_full[[tm]] %in%
                                   unique(treat_set[[tm]]))), ]

  treat_assigned <- length(unique(matches$treat_id))
  control_assigned <- length(unique(matches$control_id))
  nn <- matrix(0, ncol = 2, nrow = 3,
               dimnames = list(c("All", "Matched", "Unmatched"),
                               c("Control", "Treated")))
  nn[1, ] <- c(nrow(comp_set), nrow(treat_set))
  nn[2, ] <- c(control_assigned, treat_assigned)
  nn[3, ] <- nn[1, ] - nn[2, ]
  out$summary <- nn

  # Use the configured id/treatment columns rather than hard-coded names.
  treated_ids <- unique(lr_result[[id]][lr_result[[treat]] == 1])
  discarded <- treated_ids[!(treated_ids %in% unique(matches$treat_id))]
  out$ids_not_matched <- discarded
  out$total_not_matched <- length(discarded)
  out$matched_data <- matches

  out
}


#' Add the balancing table to the final output
#'
#' As it's an internal helper function to aid in testing, it is not exported
#' for use outside of the package.
#'
#' Means and standard deviations of every variable in \code{vars} are computed
#' by treatment group, once on the full \code{reduced_data} and once on the
#' rows (by \code{tm} and \code{id}) that appear in \code{matches}. The
#' "Comparison" columns are the first group and the "Treatment" columns the
#' second, so the treatment indicator is presumably coded 0/1.
#'
#' @param reduced_data Dataframe of reduced treatment and comparison data
#' @param vars A vector of dataframe column names; must include \code{treat}.
#' @param tm The time period indicator
#' @param id ID variable in dataset
#' @param combined_output A list of output for the rollmatch package.
#' See makeOutput
#' @param treat Name of the treatment variable column
#' @param matches Dataframe containing the matches from comparison_pool
#'
#' @examples
#' \dontrun{
#' fm <- as.formula(treat ~ qtr_pmt + yr_pmt + age + is_male + is_white +
#'                  is_disabled + is_esrd + months_dual + chron_num+lq_ed+
#'                  yr_ed2 + lq_ip + yr_ip2)
#' vars <- all.vars(fm); treat <- vars[1]
#' tm <- "quarter"; id <- "indiv_id"
#' }
#'
#' @return \code{output} returns a list containing the following additional
#' component to the list out:
#' \item{balance}{The balancing table.}
#' @keywords internal
addBalanceTable <- function(reduced_data, vars, tm, id, combined_output,
                            treat, matches){
  # Per-group means and standard deviations of the `vars` columns, transposed
  # so each variable is a row. group_by_at() groups by the column NAMED in
  # `treat`; the grouping column comes first, so its row is row 1.
  summarise_by_treat <- function(df) {
    grouped <- dplyr::group_by_at(df[, vars, drop = FALSE], treat)
    cbind(as.data.frame(t(dplyr::summarise_all(grouped, mean))),
          as.data.frame(t(dplyr::summarise_all(grouped, "sd"))))
  }

  full_summary <- summarise_by_treat(reduced_data)
  names(full_summary) <-
    c("Full Comparison Mean", "Full Treatment Mean",
      "Full Comparison Std Dev", "Full Treatment Std Dev")

  # Every (time, id) pair that appears on either side of a match.
  ta <- matches[, c("time", "treat_id")]
  ca <- matches[, c("time", "control_id")]
  names(ta) <- c(tm, id)
  names(ca) <- c(tm, id)
  data_assigned <- merge(reduced_data, unique(rbind(ta, ca)))

  matched_summary <- summarise_by_treat(data_assigned)
  names(matched_summary) <-
    c("Matched Comparison Mean", "Matched Treatment Mean",
      "Matched Comparison Std Dev", "Matched Treatment Std Dev")

  balance <- cbind(full_summary, matched_summary)
  # Drop row 1 (the treatment grouping column) and put treatment first.
  combined_output$balance <-
    balance[-1, c("Full Treatment Mean", "Full Comparison Mean",
                  "Full Treatment Std Dev", "Full Comparison Std Dev",
                  "Matched Treatment Mean", "Matched Comparison Mean",
                  "Matched Treatment Std Dev",
                  "Matched Comparison Std Dev")]

  combined_output
}
