#!/usr/bin/env Rscript
# Comprehensive Test Suite for SVG Package
# Author: Zaoqu Liu
# This script tests all methods and edge cases

# Print the suite banner (framed title) before anything else runs.
cat("========================================\n",
    "SVG Package - Comprehensive Test Suite\n",
    "========================================\n\n",
    sep = "")

# Load package
suppressPackageStartupMessages({
    library(devtools)
    load_all(".", quiet = TRUE)
    library(magrittr)
})

# Load the bundled example dataset (defines `example_svg_data`) and unpack
# the pieces every test section below reads.
load("data/example_svg_data.rda")
expr <- example_svg_data$logcounts        # expression matrix (rows = genes, cols = spots)
counts <- example_svg_data$counts         # presumably raw counts; used by the SPARKX tests
coords <- example_svg_data$spatial_coords # spot coordinates passed to every method
truth <- example_svg_data$gene_info$is_svg # per-gene ground-truth SVG flags

cat("Test data:", nrow(expr), "genes,", ncol(expr), "spots\n")
cat("True SVGs:", sum(truth), "\n\n")

# Global bookkeeping updated by run_test() via `<<-`.
errors <- list()
test_count <- pass_count <- 0

# Run one named test case and record its outcome in the global tallies.
#
# The test body is captured unevaluated (substitute) and evaluated in a fresh
# child environment of the caller. Variables the body creates (r, W, m, ...)
# therefore stay local to that test and cannot leak into, or overwrite, the
# shared fixtures (expr, coords, counts, ...). Those fixtures remain readable
# through the parent chain.
#
# Args:
#   name: label printed in the "[PASS]"/"[FAIL]" line and stored in `errors`.
#   expr: a (usually braced) block of code. The test passes iff the block
#         completes without signalling an error (e.g. from stopifnot()).
#
# Side effects: increments the global `test_count` (and `pass_count` on
# success) and appends c(name, <error message>) to `errors` on failure.
#
# Returns: TRUE if the block ran without error, FALSE otherwise.
run_test <- function(name, expr) {
    test_code <- substitute(expr)
    # Capture the caller's environment before entering tryCatch().
    caller_env <- parent.frame()
    test_count <<- test_count + 1
    result <- tryCatch({
        eval(test_code, envir = new.env(parent = caller_env))
        TRUE
    }, error = function(e) {
        errors <<- c(errors, list(c(name, conditionMessage(e))))
        FALSE
    })
    if (result) {
        pass_count <<- pass_count + 1
        cat("  [PASS]", name, "\n")
    } else {
        cat("  [FAIL]", name, "\n")
    }
    result
}

# ---- Test 1: CalSVG_MERINGUE -------------------------------------------------
# Exercises CalSVG_MERINGUE's network_method, alternative, filter_dist and
# adjust_method options. Option variants run on a 10-gene subset for speed.
cat("\n", strrep("=", 60), "\n")
cat("Test 1: CalSVG_MERINGUE\n")
cat(strrep("=", 60), "\n")

run_test("MERINGUE - Delaunay network", {
    r <- CalSVG_MERINGUE(expr, coords,
                         network_method = "delaunay", verbose = FALSE)
    stopifnot(all(c("gene", "p.value", "p.adj") %in% names(r)),
              nrow(r) == nrow(expr))
})

run_test("MERINGUE - KNN network (k=6)", {
    r <- CalSVG_MERINGUE(expr, coords,
                         network_method = "knn", k = 6, verbose = FALSE)
    stopifnot(all(c("gene", "p.value", "p.adj") %in% names(r)))
})

run_test("MERINGUE - alternative = two.sided", {
    r <- CalSVG_MERINGUE(expr[1:10, ], coords,
                         alternative = "two.sided", verbose = FALSE)
    stopifnot(!is.null(r))
})

run_test("MERINGUE - alternative = less", {
    r <- CalSVG_MERINGUE(expr[1:10, ], coords,
                         alternative = "less", verbose = FALSE)
    stopifnot(!is.null(r))
})

run_test("MERINGUE - filter_dist = 5", {
    r <- CalSVG_MERINGUE(expr[1:10, ], coords,
                         filter_dist = 5, verbose = FALSE)
    stopifnot(!is.null(r))
})

run_test("MERINGUE - adjust_method = bonferroni", {
    r <- CalSVG_MERINGUE(expr[1:10, ], coords,
                         adjust_method = "bonferroni", verbose = FALSE)
    stopifnot(!is.null(r))
})

# ---- Test 2: CalSVG_binSpect -------------------------------------------------
# Covers both binarization strategies, the do_fisher_test = FALSE path and a
# KNN spatial network.
cat("\n", strrep("=", 60), "\n")
cat("Test 2: CalSVG_binSpect\n")
cat(strrep("=", 60), "\n")

run_test("binSpect - kmeans binarization", {
    r <- CalSVG_binSpect(expr, coords,
                         bin_method = "kmeans", verbose = FALSE)
    stopifnot(all(c("gene", "p.value", "p.adj") %in% names(r)))
})

run_test("binSpect - rank binarization (25%)", {
    r <- CalSVG_binSpect(expr, coords,
                         bin_method = "rank", rank_percent = 25,
                         verbose = FALSE)
    stopifnot(!is.null(r))
})

run_test("binSpect - do_fisher_test = FALSE", {
    r <- CalSVG_binSpect(expr[1:20, ], coords,
                         do_fisher_test = FALSE, verbose = FALSE)
    stopifnot(!is.null(r), "estimate" %in% names(r))
})

run_test("binSpect - KNN network", {
    r <- CalSVG_binSpect(expr[1:20, ], coords,
                         network_method = "knn", k = 8, verbose = FALSE)
    stopifnot(!is.null(r))
})

# ---- Test 3: CalSVG_SPARKX ---------------------------------------------------
# SPARKX is run on `counts` (not the log-expression matrix); its warnings are
# suppressed so only errors count as failures.
cat("\n", strrep("=", 60), "\n")
cat("Test 3: CalSVG_SPARKX\n")
cat(strrep("=", 60), "\n")

run_test("SPARKX - single kernel", {
    r <- suppressWarnings(
        CalSVG_SPARKX(counts, coords,
                      kernel_option = "single", verbose = FALSE)
    )
    stopifnot(all(c("gene", "p.value", "p.adj") %in% names(r)))
})

run_test("SPARKX - mixture kernels (small test)", {
    r <- suppressWarnings(
        CalSVG_SPARKX(counts[1:20, ], coords,
                      kernel_option = "mixture", verbose = FALSE)
    )
    # A mixture run should add per-kernel columns whose names contain "stat_".
    stopifnot(!is.null(r), any(grepl("stat_", names(r), fixed = TRUE)))
})

run_test("SPARKX - different adjust method", {
    r <- suppressWarnings(
        CalSVG_SPARKX(counts[1:20, ], coords,
                      adjust_method = "BH", verbose = FALSE)
    )
    stopifnot(!is.null(r))
})

# ---- Test 4: CalSVG_Seurat ---------------------------------------------------
# One run per weight_scheme, plus an explicit bandwidth for the gaussian case.
cat("\n", strrep("=", 60), "\n")
cat("Test 4: CalSVG_Seurat\n")
cat(strrep("=", 60), "\n")

run_test("Seurat - inverse_squared weights", {
    r <- CalSVG_Seurat(expr, coords,
                       weight_scheme = "inverse_squared", verbose = FALSE)
    stopifnot(all(c("gene", "p.value", "p.adj") %in% names(r)))
})

run_test("Seurat - inverse weights", {
    r <- CalSVG_Seurat(expr[1:20, ], coords,
                       weight_scheme = "inverse", verbose = FALSE)
    stopifnot(!is.null(r))
})

run_test("Seurat - gaussian weights", {
    r <- CalSVG_Seurat(expr[1:20, ], coords,
                       weight_scheme = "gaussian", verbose = FALSE)
    stopifnot(!is.null(r))
})

run_test("Seurat - custom bandwidth", {
    r <- CalSVG_Seurat(expr[1:20, ], coords,
                       weight_scheme = "gaussian", bandwidth = 10,
                       verbose = FALSE)
    stopifnot(!is.null(r))
})

# ---- Test 5: CalSVG_MarkVario (if spatstat available) ------------------------
# Skipped unless both spatstat.geom and spatstat.explore can be loaded.
cat("\n", strrep("=", 60), "\n")
cat("Test 5: CalSVG_MarkVario\n")
cat(strrep("=", 60), "\n")

if (!(requireNamespace("spatstat.geom", quietly = TRUE) &&
      requireNamespace("spatstat.explore", quietly = TRUE))) {
    cat("  [SKIP] spatstat not installed\n")
} else {
    run_test("MarkVario - basic test", {
        r <- CalSVG_MarkVario(expr[1:10, ], coords, verbose = FALSE)
        stopifnot("r.metric.value" %in% names(r), "rank" %in% names(r))
    })

    run_test("MarkVario - different r_metric", {
        r <- CalSVG_MarkVario(expr[1:10, ], coords,
                              r_metric = 10, verbose = FALSE)
        stopifnot(!is.null(r))
    })
}

# ---- Test 6: CalSVG_nnSVG (if BRISC available) -------------------------------
# Skipped when BRISC is missing; uses very few genes to keep runtime short.
cat("\n", strrep("=", 60), "\n")
cat("Test 6: CalSVG_nnSVG\n")
cat(strrep("=", 60), "\n")

if (!requireNamespace("BRISC", quietly = TRUE)) {
    cat("  [SKIP] BRISC not installed\n")
} else {
    run_test("nnSVG - basic test", {
        r <- CalSVG_nnSVG(expr[1:5, ], coords, verbose = FALSE)
        stopifnot(all(c("gene", "p.value", "p.adj") %in% names(r)),
                  "prop_sv" %in% names(r))
    })

    run_test("nnSVG - different n_neighbors", {
        r <- CalSVG_nnSVG(expr[1:3, ], coords,
                          n_neighbors = 5, verbose = FALSE)
        stopifnot(!is.null(r))
    })
}

# ---- Test 7: CalSVG Unified Interface ----------------------------------------
# Dispatch through CalSVG(method = ...) for each listed backend.
cat("\n", strrep("=", 60), "\n")
cat("Test 7: CalSVG Unified Interface\n")
cat(strrep("=", 60), "\n")

methods_to_test <- c("meringue", "binspect", "sparkx", "seurat")
for (m in methods_to_test) {
    run_test(sprintf("CalSVG - method = %s", m), {
        r <- suppressWarnings(
            CalSVG(expr[1:10, ], coords, method = m, verbose = FALSE)
        )
        stopifnot(!is.null(r), "gene" %in% names(r))
    })
}

# ---- Test 8: Utility Functions -----------------------------------------------
# Spatial network construction, Moran's I helpers, binarization and ACAT
# p-value combination.
cat("\n", strrep("=", 60), "\n")
cat("Test 8: Utility Functions\n")
cat(strrep("=", 60), "\n")

run_test("buildSpatialNetwork - delaunay", {
    W <- buildSpatialNetwork(coords, method = "delaunay")
    # Expect a square matrix with one row per spot.
    stopifnot(is.matrix(W),
              nrow(W) == ncol(W),
              nrow(W) == nrow(coords))
})

run_test("buildSpatialNetwork - knn", {
    W <- buildSpatialNetwork(coords, method = "knn", k = 8)
    stopifnot(is.matrix(W))
})

run_test("moranI", {
    W <- buildSpatialNetwork(coords, method = "knn", k = 6)
    m <- moranI(expr[1, ], W)
    stopifnot(all(c("observed", "expected", "sd") %in% names(m)),
              m$observed >= -1 && m$observed <= 1)
})

run_test("moranI_test", {
    W <- buildSpatialNetwork(coords, method = "knn", k = 6)
    m <- moranI_test(expr[1, ], W)
    stopifnot("p.value" %in% names(m),
              m["p.value"] >= 0 && m["p.value"] <= 1)
})

run_test("binarize_expression - kmeans", {
    b <- binarize_expression(expr[1:5, ], method = "kmeans")
    stopifnot(all(b %in% c(0, 1)))
})

run_test("binarize_expression - rank", {
    b <- binarize_expression(expr[1:5, ], method = "rank", rank_percent = 30)
    stopifnot(all(b %in% c(0, 1)))
})

run_test("ACAT_combine - basic", {
    p <- ACAT_combine(c(0.01, 0.05, 0.1))
    stopifnot(p >= 0 && p <= 1)
})

run_test("ACAT_combine - with weights", {
    p <- ACAT_combine(c(0.01, 0.05, 0.1), weights = c(1, 2, 3))
    stopifnot(p >= 0 && p <= 1)
})

# ==============================================================================
# Test 9: Edge Cases
# ==============================================================================
cat("\n", paste(rep("=", 60), collapse = ""), "\n")
cat("Test 9: Edge Cases\n")
cat(paste(rep("=", 60), collapse = ""), "\n")

# A constant (zero-variance) gene is expected to remain in the results with an
# NA p-value.
run_test("Zero variance gene", {
    expr_zero <- expr
    expr_zero[1, ] <- 5  # All same value
    gene_name <- rownames(expr_zero)[1]
    r <- CalSVG_MERINGUE(expr_zero[1:10, ], coords, verbose = FALSE)
    # Find by gene name (not row index, since results are sorted).
    # %in% never yields NA (unlike ==), so an NA gene id cannot select a
    # spurious all-NA row.
    gene1_row <- r[r$gene %in% gene_name, ]
    # Require exactly one match: if the gene were missing, is.na() below would
    # return logical(0) and stopifnot() would pass vacuously.
    stopifnot(nrow(gene1_row) == 1)
    stopifnot(is.na(gene1_row$p.value))  # Should be NA for zero-var gene
})

run_test("Single gene", {
    r <- CalSVG_MERINGUE(expr[1, , drop = FALSE], coords, verbose = FALSE)
    stopifnot(nrow(r) == 1)
})

# Coordinates whose row names do not match the expression columns must be
# rejected with an error.
run_test("Sample name mismatch - should error", {
    coords_bad <- coords
    rownames(coords_bad) <- paste0("bad_", seq_len(nrow(coords_bad)))
    result <- tryCatch({
        CalSVG_MERINGUE(expr[1:5, ], coords_bad, verbose = FALSE)
        FALSE  # Should not reach here
    }, error = function(e) TRUE)
    stopifnot(result)  # Should have errored
})

run_test("Very small expression values", {
    expr_small <- expr * 1e-10
    r <- CalSVG_MERINGUE(expr_small[1:10, ], coords, verbose = FALSE)
    stopifnot(!is.null(r))
})

run_test("Large expression values", {
    expr_large <- expr * 1e6
    r <- CalSVG_MERINGUE(expr_large[1:10, ], coords, verbose = FALSE)
    stopifnot(!is.null(r))
})

# ---- Test 10: Data Simulation ------------------------------------------------
# simulate_spatial_data(): output dimensions, SVG count, and both grid layouts.
cat("\n", strrep("=", 60), "\n")
cat("Test 10: Data Simulation\n")
cat(strrep("=", 60), "\n")

run_test("simulate_spatial_data - basic", {
    sim <- simulate_spatial_data(n_spots = 100, n_genes = 50, n_svg = 10)
    stopifnot(nrow(sim$counts) == 50,
              ncol(sim$counts) == 100,
              sum(sim$gene_info$is_svg) == 10)
})

run_test("simulate_spatial_data - hexagonal grid", {
    sim <- simulate_spatial_data(n_spots = 100, grid_type = "hexagonal")
    stopifnot(!is.null(sim$spatial_coords))
})

run_test("simulate_spatial_data - square grid", {
    sim <- simulate_spatial_data(n_spots = 100, grid_type = "square")
    stopifnot(!is.null(sim$spatial_coords))
})

# ==============================================================================
# Summary
# ==============================================================================
cat("\n", paste(rep("=", 60), collapse = ""), "\n")
cat("TEST SUMMARY\n")
cat(paste(rep("=", 60), collapse = ""), "\n")

cat(sprintf("\nTotal tests: %d\n", test_count))
cat(sprintf("Passed: %d\n", pass_count))
cat(sprintf("Failed: %d\n", test_count - pass_count))

if (length(errors) == 0) {
    cat("\n*** ALL TESTS PASSED! ***\n\n")
} else {
    # Each entry of `errors` is c(test name, error message) from run_test().
    cat("\nFailed tests:\n")
    for (e in errors) {
        cat(sprintf("  - %s: %s\n", e[1], e[2]))
    }
    cat("\n")
}

# Exit with status 1 if any test failed so callers/CI can detect failure.
# save = "no" is explicit so that quitting never prompts for (or writes) a
# workspace image, e.g. when this script is run via source() interactively.
quit(save = "no", status = if (length(errors) == 0) 0 else 1)
