## ----setup, include=FALSE-----------------------------------------------------
# Knitr defaults applied to every chunk: show the code, but keep package
# messages and warnings out of the rendered output.
knitr::opts_chunk$set(echo = TRUE, message = FALSE, warning = FALSE)

## ----setup-data---------------------------------------------------------------
library(cdCAT)

# Dataset: K = 3, GDINA model, high-discrimination / low-variability
d <- cdcat_sim[["k3_GDINA_HDLV"]]

Q         <- d$Q          # 70 x 3
alpha     <- d$alpha      # 160 x 3  (true profiles)
responses <- d$responses  # 160 x 70
K         <- ncol(Q)
N         <- nrow(alpha)
J         <- nrow(Q)

# The simulation reads responses[i, item] for i in 1..N and item in 1..J, and
# compares estimated profiles to alpha column by column. Fail here, with a
# clear message, if the dataset components disagree in shape.
stopifnot(
  nrow(responses) == N,
  ncol(responses) == J,
  ncol(alpha)     == K
)

# Build the item bank
items <- cdcat_items(Q, "GDINA", gdina_params = d$parameters)
items

## ----fixed-start--------------------------------------------------------------
# Pick a single start item, shared by every condition and every examinee.
# For J >= 1, sample.int(J, 1L) draws exactly what sample(J, 1L) would under
# the same seed, without sample()'s length-one special case.
set.seed(42)
start_item <- sample.int(J, size = 1L)
cat("Fixed start item:", start_item, "\n")

## ----conditions---------------------------------------------------------------
# Full factorial design: 4 item-selection criteria x 2 stopping rules.
# Columns stay character (not factor) so each value can be handed directly to
# run_condition().
conditions <- expand.grid(
  list(
    criterion = c("KL", "PWKL", "MPWKL", "SHE"),
    stop_rule = c("single", "dual")
  ),
  stringsAsFactors = FALSE
)

## ----simulate-----------------------------------------------------------------
# Run one simulation condition for every examinee.
#
# Each of the N examinees gets a fresh CdcatSession that starts from the fixed
# `start_item`. Every item the session selects is answered with the
# pre-generated response in `responses`.
#
# NOTE: besides its arguments, this function reads `items`, `N`, `J`,
# `responses` and `start_item` from the environment it was defined in (the
# global environment). Those objects are created in earlier chunks.
#
# criterion : item-selection criterion passed to CdcatSession$new().
# stop_rule : "single" (threshold 0.80) or "dual" (threshold c(0.70, 0.10)).
#             Any other value is an error.
#
# Returns a list of length N. Element i holds `alpha_hat` and `n_items` taken
# from examinee i's session$result().
run_condition <- function(criterion, stop_rule) {

  threshold <- switch(
    as.character(stop_rule),
    single = 0.80,
    dual   = c(0.70, 0.10),
    stop("Unknown stop_rule '", stop_rule,
         "' (expected \"single\" or \"dual\").", call. = FALSE)
  )

  results <- vector("list", N)

  for (i in seq_len(N)) {
    session <- CdcatSession$new(
      items,
      criterion  = criterion,
      method     = "MAP",
      prior      = NULL,        # uniform prior
      threshold  = threshold,
      min_items  = 1L,
      max_items  = J,           # effectively no maximum: bounded by bank size
      start_item = start_item
    )

    # next_item() returning 0 ends the loop. With max_items = J the session
    # should never administer more than J items; the counter turns a
    # misbehaving session into an error instead of an endless loop.
    n_given <- 0L
    repeat {
      item <- session$next_item()
      if (item == 0L) break
      n_given <- n_given + 1L
      if (n_given > J) {
        stop("Session for examinee ", i, " did not stop after ", J,
             " items.", call. = FALSE)
      }
      session$update(item, responses[i, item])
    }

    res <- session$result()
    results[[i]] <- list(
      alpha_hat = res$alpha_hat,
      n_items   = res$n_items
    )
  }

  results
}

# Run all conditions in design order under one seed, so the whole simulation
# is reproducible. Results are keyed "<criterion>_<stop_rule>".
set.seed(2025)
sim_results <- Map(run_condition, conditions$criterion, conditions$stop_rule)
names(sim_results) <- paste(conditions$criterion, conditions$stop_rule,
                            sep = "_")

## ----compute-metrics----------------------------------------------------------
# Summarise one condition's simulation output.
#
# results : list with one element per examinee. Each element is a list holding
#           `alpha_hat` (estimated attribute profile) and `n_items` (number of
#           items administered), as built by run_condition().
# alpha   : true profiles, one row per examinee and one column per attribute.
#           Rows must be in the same order as `results`.
#
# Returns a named numeric vector:
#   PCCR     proportion of examinees whose whole profile is correct (4 dp)
#   ACCR_Ak  proportion correct on attribute k, for k = 1..K (4 dp)
#   ATL      mean number of items administered (2 dp)
compute_metrics <- function(results, alpha) {
  alpha <- as.matrix(alpha)
  K     <- ncol(alpha)

  alpha_hat_mat <- do.call(rbind, lapply(results, `[[`, "alpha_hat"))
  # vapply keeps the result numeric even for empty or irregular input.
  n_items_vec   <- vapply(results, `[[`, numeric(1), "n_items")

  if (!identical(dim(alpha_hat_mat), dim(alpha))) {
    stop("Estimated and true profiles differ in shape: ",
         paste(dim(alpha_hat_mat), collapse = " x "), " vs ",
         paste(dim(alpha), collapse = " x "), call. = FALSE)
  }

  # Elementwise agreement between estimated and true attribute states.
  hits <- alpha_hat_mat == alpha

  # PCCR: all K attributes correct
  pccr <- mean(rowSums(hits) == K)

  # ACCR per attribute
  accr <- colMeans(hits)
  names(accr) <- paste0("ACCR_A", seq_len(K))

  # ATL
  atl <- mean(n_items_vec)

  c(PCCR = round(pccr, 4),
    round(accr, 4),
    ATL  = round(atl, 2))
}

metrics_list <- lapply(sim_results, compute_metrics, alpha = alpha)
# One row per condition: the design columns first, then the metric columns in
# the order compute_metrics() returns them.
metrics_df <- data.frame(
  conditions,
  do.call(rbind, metrics_list),
  check.names = FALSE
)
rownames(metrics_df) <- NULL

## ----table-full---------------------------------------------------------------
# One row per condition. The ACCR column labels are derived from the metric
# columns themselves (ACCR_A1 -> "ACCR A1", ...), so the table stays valid
# whatever the number of attributes.
accr_cols <- grep("^ACCR_", names(metrics_df), value = TRUE)
knitr::kable(
  metrics_df,
  col.names = c("Criterion", "Stop rule", "PCCR",
                sub("_", " ", accr_cols, fixed = TRUE), "ATL"),
  caption   = "Classification accuracy and average test length by condition.",
  digits    = 4,
  # two left-aligned design columns; PCCR, every ACCR column and ATL centred
  align     = paste0("ll", strrep("c", length(accr_cols) + 2L))
)

## ----table-criterion----------------------------------------------------------
# Mean PCCR and ATL per criterion, pooled over stopping rules. The grouping
# column is turned into a factor in design order, so rows appear in the same
# order as in the full results table. Without this, aggregate() sorts the
# character values alphabetically.
agg_crit <- aggregate(
  cbind(PCCR, ATL) ~ criterion,
  data = transform(
    metrics_df,
    criterion = factor(criterion, levels = unique(criterion))
  ),
  FUN  = mean
)
knitr::kable(
  agg_crit,
  col.names = c("Criterion", "Mean PCCR", "Mean ATL"),
  caption   = "Average PCCR and ATL across stopping rules, by criterion.",
  digits    = 4
)

## ----table-stoprule-----------------------------------------------------------
# Mean PCCR and ATL per stopping rule, pooled over criteria. The grouping
# column is turned into a factor in design order ("single" before "dual"), so
# the rows match the other tables. Without this, aggregate() sorts the values
# alphabetically.
agg_stop <- aggregate(
  cbind(PCCR, ATL) ~ stop_rule,
  data = transform(
    metrics_df,
    stop_rule = factor(stop_rule, levels = unique(stop_rule))
  ),
  FUN  = mean
)
knitr::kable(
  agg_stop,
  col.names = c("Stop rule", "Mean PCCR", "Mean ATL"),
  caption   = "Average PCCR and ATL across criteria, by stopping rule.",
  digits    = 4
)

## ----heatmap, fig.width=5, fig.height=3.5, fig.cap="PCCR by criterion and stopping rule."----
# PCCR as a criterion x stopping-rule matrix. Cells are filled by matching
# dimnames (character-matrix indexing) rather than by position, so the layout
# does not depend on the row order of metrics_df.
crit_levels <- unique(metrics_df$criterion)
stop_levels <- unique(metrics_df$stop_rule)
pccr_mat <- matrix(
  NA_real_,
  nrow     = length(crit_levels),
  ncol     = length(stop_levels),
  dimnames = list(Criterion = crit_levels, Stop = stop_levels)
)
pccr_mat[cbind(metrics_df$criterion, metrics_df$stop_rule)] <- metrics_df$PCCR

# image() draws the rows of its matrix along x and the columns along y, at
# evenly spaced positions on [0, 1]. t() therefore puts stopping rules on x
# and criteria on y.
x_at <- seq(0, 1, length.out = ncol(pccr_mat))
y_at <- seq(0, 1, length.out = nrow(pccr_mat))

image(
  t(pccr_mat),
  col    = hcl.colors(20, "Blues", rev = TRUE),
  xaxt   = "n", yaxt = "n",
  zlim   = c(min(pccr_mat, na.rm = TRUE) - 0.02, 1),
  main   = "PCCR",
  xlab   = "Stopping rule",
  ylab   = "Criterion"
)
axis(1, at = x_at, labels = colnames(pccr_mat))
axis(2, at = y_at, labels = rownames(pccr_mat), las = 2)
# Label every cell with its value. t(pccr_mat) unrolls with the stopping rule
# varying fastest, matching the x/y recycling below.
text(
  x      = rep(x_at, times = length(y_at)),
  y      = rep(y_at, each = length(x_at)),
  labels = sprintf("%.3f", t(pccr_mat)),
  cex    = 0.85
)

## ----summary------------------------------------------------------------------
# Best condition for each target: highest PCCR and shortest average test
# length. which.max()/which.min() take the first row on ties.
best_pccr <- metrics_df[which.max(metrics_df$PCCR),
                        c("criterion", "stop_rule", "PCCR", "ATL")]
best_atl  <- metrics_df[which.min(metrics_df$ATL),
                        c("criterion", "stop_rule", "PCCR", "ATL")]

knitr::kable(
  rbind(
    data.frame(Optimises = "PCCR (max)", best_pccr),
    data.frame(Optimises = "ATL (min)",  best_atl)
  ),
  row.names = FALSE,
  caption   = "Best condition per optimisation target."
)

