#' Fit a Gower-SOM for mixed-attribute data
#'
#' Trains a self-organising map whose prototypes live in the same mixed
#' attribute space as `data`. For each sampled row the best-matching unit
#' (BMU) is the neuron with the smallest Gower distance
#' ([StatMatch::gower.dist()]). Neurons within the current radius of the BMU
#' on the grid are then updated:
#' - numeric prototype columns move towards the input, scaled by the learning
#'   rate and a Gaussian neighbourhood weight;
#' - categorical prototype columns are replaced by a single value returned by
#'   `gsom_updateCategorical()`, computed from the neighbours' current values
#'   (weight 1 each) plus the input value (weighted by the neighbourhood
#'   weights). All updated neighbours receive that same value.
#'
#' Note: calls `set.seed(set_seed)`, which changes the global RNG state.
#'
#' @param data data.frame with correctly typed columns (numeric, factor, etc.).
#'   Character columns are converted to factors and integer columns are
#'   treated as numeric. Missing values are ignored when initialising
#'   prototypes and are skipped during updates.
#' @param grid_rows,grid_cols SOM grid dimensions (rows x cols).
#' @param learning_rate initial learning rate (decays exponentially).
#' @param num_iterations number of iterations (0 returns the initial
#'   prototypes).
#' @param radius optional initial neighborhood radius (must be > 0); defaults
#'   to max(grid_rows, grid_cols)/2. A radius greater than 1 decays
#'   exponentially towards 1 over training; a radius of at most 1 is held
#'   constant.
#' @param batch_size mini-batch size per iteration; capped at `nrow(data)`.
#' @param sampling logical; if TRUE uses multinomial sampling to update categorical prototypes.
#' @param set_seed integer random seed for reproducibility.
#' @param verbose logical; if TRUE (default) emit a progress message after
#'   every iteration.
#' @return A list with elements `weights` (data.frame of neuron prototypes,
#'   one row per neuron, same columns as `data`) and `coords` (data.frame of
#'   grid coordinates with columns `row` and `col`, in the same neuron order
#'   as `weights`).
#' @export
#' @importFrom StatMatch gower.dist
#' @importFrom dplyr mutate_if
#'
gsom_Training <- function(data, grid_rows = 5, grid_cols = 5,
                          learning_rate = 0.1, num_iterations = 100,
                          radius = NULL, batch_size = 10,
                          sampling = TRUE, set_seed = 123, verbose = TRUE) {
  set.seed(set_seed)

  if (!is.data.frame(data) || nrow(data) == 0L) {
    stop("`data` must be a data.frame with at least one row.", call. = FALSE)
  }

  # Normalise column types once, up front. Character columns become factors
  # so they have levels to draw prototypes from. Integer columns become
  # doubles so batch and prototype columns share a type for gower.dist()
  # (previously this conversion was redone on every mini-batch).
  data[] <- lapply(data, function(col) {
    if (is.character(col)) {
      factor(col)
    } else if (is.integer(col)) {
      as.numeric(col)
    } else {
      col
    }
  })

  n_features <- ncol(data)
  n_neurons <- grid_rows * grid_cols
  n_samples <- nrow(data)
  is_num <- vapply(data, is.numeric, logical(1))
  col_levels <- lapply(data, levels)

  # Weight initialisation ----
  # Numeric prototypes are uniform within the observed range; categorical
  # prototypes are drawn uniformly from the column's levels.
  weights <- data.frame(matrix(0, nrow = n_neurons, ncol = n_features))
  colnames(weights) <- colnames(data)

  for (j in seq_len(n_features)) {
    col <- data[[j]]
    if (is_num[[j]]) {
      weights[[j]] <- runif(n_neurons,
                            min = min(col, na.rm = TRUE),
                            max = max(col, na.rm = TRUE))
    } else {
      lv <- col_levels[[j]]
      weights[[j]] <- factor(sample(lv, n_neurons, replace = TRUE),
                             levels = lv)
    }
  }

  neuron_coords <- expand.grid(row = 1:grid_rows, col = 1:grid_cols)
  coord_matrix <- as.matrix(neuron_coords)

  # Neighbourhood radius schedule ----
  if (is.null(radius)) radius <- max(grid_rows, grid_cols) / 2
  if (!is.numeric(radius) || length(radius) != 1L || is.na(radius) ||
      radius <= 0) {
    stop("`radius` must be a single positive number.", call. = FALSE)
  }
  # This time constant shrinks the radius from `radius` to 1 over the run.
  # When radius <= 1, log(radius) <= 0 would make the neighbourhood grow (or
  # divide by zero), so a small radius is held constant instead.
  radius_tau <- if (radius > 1) num_iterations / log(radius) else Inf

  # sample.int() draws the same indices as sample(1:n, k), so seeded runs
  # are unchanged; capping avoids "sample larger than the population".
  batch_n <- min(batch_size, n_samples)

  for (iter in seq_len(num_iterations)) {
    lr <- learning_rate * exp(-iter / num_iterations)
    current_radius <- radius * exp(-iter / radius_tau)

    batch <- data[sample.int(n_samples, batch_n), , drop = FALSE]
    gower_matrix <- gower.dist(batch, weights)  # batch rows x neurons

    for (i in seq_len(batch_n)) {
      input_vector <- batch[i, , drop = FALSE]
      bmu_index <- which.min(gower_matrix[i, ])
      # All distances NA (e.g. an all-missing row): nothing to learn from.
      if (length(bmu_index) == 0L) next

      grid_diffs <- sweep(coord_matrix, 2, coord_matrix[bmu_index, ])
      dists <- sqrt(rowSums(grid_diffs^2))
      neighbors <- which(dists <= current_radius)
      h_vec <- exp(-dists[neighbors]^2 / (2 * current_radius^2))

      for (k in seq_len(n_features)) {
        input_val <- input_vector[[k]]
        # A missing input carries no information for this attribute; using
        # it would turn the neighbours' prototypes into NA permanently.
        if (is.na(input_val)) next

        if (is_num[[k]]) {
          weight_vals <- as.numeric(weights[neighbors, k])
          weights[neighbors, k] <-
            weight_vals + lr * h_vec * (as.numeric(input_val) - weight_vals)
        } else {
          lv <- col_levels[[k]]
          current_vals <- as.character(weights[neighbors, k])
          input_vals <- rep(as.character(input_val), length(h_vec))
          combined_vals <- factor(c(current_vals, input_vals), levels = lv)
          combined_weights <- c(rep(1, length(current_vals)), h_vec)

          new_cat_value <- as.character(
            gsom_updateCategorical(combined_vals, combined_weights,
                                   sampling = sampling)
          )
          weights[neighbors, k] <- factor(
            rep(new_cat_value, length(neighbors)),
            levels = lv
          )
        }
      }
    }
    if (verbose) message("Iteration: ", iter)
  }

  list(
    weights = weights,
    coords = neuron_coords
  )
}

