## ----setup, include = FALSE---------------------------------------------------
# Chunk defaults for the rendered document: source and output collapsed into
# one block, "#>" prefix on printed output, 7 x 5 inch figures.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>",
                      fig.width = 7, fig.height = 5)

# Document-wide seed; individual chunks below reseed where they need to.
set.seed(2025)

## ----install, eval = FALSE----------------------------------------------------
# # From the package source directory:
# devtools::install("swjm")
# 
# # Or from a built tarball:
# install.packages("swjm_0.1.0.tar.gz", repos = NULL, type = "source")

## ----library------------------------------------------------------------------
library(swjm)

## ----gen-jfm------------------------------------------------------------------
# Simulate a dataset under the "jfm" model (presumably the joint frailty
# model): n = 500 subjects, p = 10 covariates, simulation scenario 1.
# generate_data() returns a list; `data` holds the subject-level records with
# one or more rows per subject (id, t.start, t.stop, event, status, x1..xp).
set.seed(123)
dat_jfm  <- generate_data(n = 500, p = 10, scenario = 1, model = "jfm")
Data_jfm <- dat_jfm$data

# Preview
head(Data_jfm[, 1:8])

## ----true-jfm-alpha-----------------------------------------------------------
# True coefficients used by the simulator for the readmission process (alpha).
round(dat_jfm$alpha_true, 2)

## ----true-jfm-beta------------------------------------------------------------
# True coefficients used by the simulator for the death process (beta).
round(dat_jfm$beta_true, 2)

## ----gen-jscm-----------------------------------------------------------------
# Same design simulated under the "jscm" model (NOTE(review): presumably the
# joint scale-change model -- confirm against the package docs).
set.seed(456)
dat_jscm  <- generate_data(n = 500, p = 10, scenario = 1, model = "jscm")
Data_jscm <- dat_jscm$data

## ----fit-jfm------------------------------------------------------------------
# Fit the full stagewise path on the JFM data.
fit_jfm <- stagewise_fit(
  Data_jfm,
  model   = "jfm",
  penalty = "coop"    # cooperative lasso
)
fit_jfm

## ----path-explore-------------------------------------------------------------
# fit_jfm$alpha / fit_jfm$beta are matrices whose columns appear to index the
# steps of the path; k is the last column. active_final lists the covariates
# that are nonzero in either submodel at that final step.
p <- 10
k <- ncol(fit_jfm$alpha)
active_final <- which(fit_jfm$alpha[, k] != 0 |
                      fit_jfm$beta[, k]  != 0)

## ----path-explore-alpha-------------------------------------------------------
# Readmission coefficients at the final step of the path.
round(fit_jfm$alpha[, k], 4)

## ----path-summary-------------------------------------------------------------
summary(fit_jfm)

## ----plot-path, fig.height = 8------------------------------------------------
# Default coefficient-path plot.
plot(fit_jfm)

## ----plot-path-re, fig.height = 5---------------------------------------------
# Path plot restricted to the readmission process.
plot(fit_jfm, which = "readmission")

## ----cv-jfm-prep--------------------------------------------------------------
# Lambda grid for cross-validation, taken from the fitted path.
# extract_decreasing_indices() is an internal swjm helper (reached via :::);
# presumably it keeps the path steps at which lambda decreases -- TODO confirm.
lambda_path <- fit_jfm$lambda
dec_idx     <- swjm:::extract_decreasing_indices(lambda_path)
lambda_seq  <- lambda_path[dec_idx]

## ----cv-jfm, cache = TRUE-----------------------------------------------------
# 3-fold cross-validation over that grid. The seed is set so that any random
# fold assignment inside cv_stagewise() is reproducible (assumed random).
set.seed(1)
cv_jfm <- cv_stagewise(
  Data_jfm,
  model      = "jfm",
  penalty    = "coop",
  lambda_seq = lambda_seq,
  K          = 3L
)
cv_jfm

## ----plot-cv------------------------------------------------------------------
plot(cv_jfm)

## ----coef-jfm-alpha-----------------------------------------------------------
# Nonzero CV-selected readmission coefficients.
round(cv_jfm$alpha[cv_jfm$alpha != 0], 4)

## ----coef-jfm-beta------------------------------------------------------------
# Nonzero CV-selected death coefficients.
round(cv_jfm$beta[cv_jfm$beta != 0], 4)

## ----summary-jfm--------------------------------------------------------------
summary(cv_jfm)

## ----coef-vec-----------------------------------------------------------------
# coef() returns a single vector covering both submodels (length 2p).
theta_best <- coef(cv_jfm)
length(theta_best)  # 2p

## ----baseline-----------------------------------------------------------------
# Baseline hazard evaluated at a few chosen times.
bh <- baseline_hazard(cv_jfm, times = c(0.5, 1.0, 2.0, 4.0, 6.0))
print(bh)

## ----baseline-re--------------------------------------------------------------
# Baseline hazard of the readmission process on a regular grid of times.
bh_re <- baseline_hazard(cv_jfm, times = seq(0, 5, by = 0.5),
                         which = "readmission")
head(bh_re)

## ----predict-jfm, fig.height = 7----------------------------------------------
# Twelve hypothetical new subjects with ten standard-normal covariates each.
# The matrix has 12 * 10 = 120 cells, so draw 120 values. With only 30 draws,
# matrix() silently recycles them column-wise, and columns x6..x10 become
# exact copies of x1..x5 for every patient.
set.seed(7)
newz <- matrix(rnorm(12 * 10), nrow = 12, ncol = 10)
rownames(newz) <- paste0("Patient_", 1:12)
colnames(newz) <- paste0("x", 1:10)

pred <- predict(cv_jfm, newdata = newz)
pred

## ----pred-survival------------------------------------------------------------
# Survival probabilities for all subjects at first few time points
round(pred$S_re[, 1:5], 3)

## ----plot-pred, fig.height = 8------------------------------------------------
plot(pred, which_subject = 7)

## ----plot-pred-re, fig.height = 5---------------------------------------------
plot(pred, which_subject = 2, which_process = "readmission")

## ----lasso, eval = FALSE------------------------------------------------------
# fit_lasso <- stagewise_fit(Data_jfm, model = "jfm", penalty = "lasso")
# set.seed(2)
# cv_lasso <- cv_stagewise(Data_jfm, model = "jfm", penalty = "lasso", K = 3L)
# summary(cv_lasso)

## ----group, eval = FALSE------------------------------------------------------
# fit_group <- stagewise_fit(Data_jfm, model = "jfm", penalty = "group")
# set.seed(3)
# cv_group <- cv_stagewise(Data_jfm, model = "jfm", penalty = "group", K = 3L)
# summary(cv_group)

## ----fit-jscm-----------------------------------------------------------------
# Regenerate the JSCM data with the same seed and arguments as the gen-jscm
# chunk, so this chunk is self-contained. The result should be the same data
# (assuming generate_data() is deterministic given the seed).
set.seed(456)
dat_jscm  <- generate_data(n = 500, p = 10, scenario = 1, model = "jscm")
Data_jscm <- dat_jscm$data

fit_jscm <- stagewise_fit(Data_jscm, model = "jscm", penalty = "coop")
fit_jscm

## ----cv-jscm, cache = TRUE----------------------------------------------------
# Lambda grid from the JSCM path, built the same way as for the JFM fit.
lambda_path_jscm <- fit_jscm$lambda
dec_idx_jscm     <- swjm:::extract_decreasing_indices(lambda_path_jscm)
lambda_seq_jscm  <- lambda_path_jscm[dec_idx_jscm]

# 3-fold cross-validation; seeded for reproducibility.
set.seed(10)
cv_jscm <- cv_stagewise(
  Data_jscm,
  model      = "jscm",
  penalty    = "coop",
  lambda_seq = lambda_seq_jscm,
  K          = 3L
)
cv_jscm

## ----plot-cv-jscm-------------------------------------------------------------
plot(cv_jscm)

## ----summary-jscm-------------------------------------------------------------
summary(cv_jscm)

## ----baseline-jscm------------------------------------------------------------
# Baseline hazard of the CV-selected JSCM fit at a few chosen times.
bh_jscm <- baseline_hazard(cv_jscm, times = c(0.5, 1.0, 2.0, 3.0, 4.0))
print(bh_jscm)

## ----predict-jscm-------------------------------------------------------------
# Three hypothetical new subjects with covariates drawn from U(-1, 1).
# Name the columns x1..x10, like the training covariates and the JFM
# prediction example, so any per-covariate output carries the same labels.
set.seed(7)
newz_jscm <- matrix(runif(30, -1, 1), nrow = 3, ncol = 10)
rownames(newz_jscm) <- paste0("Patient_", 1:3)
colnames(newz_jscm) <- paste0("x", 1:10)

pred_jscm <- predict(cv_jscm, newdata = newz_jscm)
pred_jscm

## ----predict-jscm-accel-------------------------------------------------------
round(pred_jscm$time_accel_re, 3)

## ----plot-pred-jscm, fig.height = 8-------------------------------------------
plot(pred_jscm, which_subject = 1)

## ----interpret----------------------------------------------------------------
# CV-selected coefficients: alpha (readmission) and beta (death).
a <- cv_jfm$alpha
b <- cv_jfm$beta

# Support of each submodel, and the covariates selected by both.
nz_a   <- which(a != 0)
nz_b   <- which(b != 0)
shared <- intersect(nz_a, nz_b)

# Split the shared covariates by whether the two effects agree in sign.
# When nothing is shared, both groups are empty integer vectors.
if (length(shared) == 0) {
  same_sign <- integer(0)
  opp_sign  <- integer(0)
} else {
  same_sign <- shared[sign(a[shared]) == sign(b[shared])]
  opp_sign  <- shared[sign(a[shared]) != sign(b[shared])]
}

## ----contrib-example----------------------------------------------------------
# First row of each contribution matrix, i.e. the first new subject.
c1_re <- pred$contrib_re[1, ]
c1_de <- pred$contrib_de[1, ]

## ----contrib-re---------------------------------------------------------------
round(c1_re[abs(c1_re) > 0], digits = 4)

## ----contrib-de---------------------------------------------------------------
round(c1_de[abs(c1_de) > 0], digits = 4)

## ----coef-compare-------------------------------------------------------------
p <- 10

# Rows to display: covariates that are truly active or CV-selected in either
# the readmission (alpha) or the death (beta) submodel.
show_jfm <- sort(which(Reduce(`|`, list(
  dat_jfm$alpha_true != 0, cv_jfm$alpha != 0,
  dat_jfm$beta_true  != 0, cv_jfm$beta  != 0
))))

# True vs. CV-estimated coefficients, rounded to three decimals.
coef_df <- data.frame(
  variable   = paste0("x", show_jfm),
  alpha_true = round(dat_jfm$alpha_true[show_jfm], 3),
  alpha_est  = round(cv_jfm$alpha[show_jfm], 3),
  beta_true  = round(dat_jfm$beta_true[show_jfm], 3),
  beta_est   = round(cv_jfm$beta[show_jfm], 3)
)
print(coef_df, row.names = FALSE)

## ----coef-compare-jscm--------------------------------------------------------
# Same comparison for the JSCM fit.
show_jscm <- sort(which(Reduce(`|`, list(
  dat_jscm$alpha_true != 0, cv_jscm$alpha != 0,
  dat_jscm$beta_true  != 0, cv_jscm$beta  != 0
))))

coef_jscm <- data.frame(
  variable   = paste0("x", show_jscm),
  alpha_true = round(dat_jscm$alpha_true[show_jscm], 3),
  alpha_est  = round(cv_jscm$alpha[show_jscm], 3),
  beta_true  = round(dat_jscm$beta_true[show_jscm], 3),
  beta_est   = round(cv_jscm$beta[show_jscm], 3)
)
print(coef_jscm, row.names = FALSE)

## ----auc-prep-----------------------------------------------------------------
# Construct a first-event competing-risk dataset with ONE row per subject:
# the subject's earliest interval. That interval ends at the first readmission
# (event == 1) or, for subjects never readmitted, at death/censoring
# (event == 0). This assumes each subject's intervals are contiguous, so the
# earliest one ends at the first event.
# Status: 1 = first readmission, 2 = death before any readmission,
#         0 = censored before any event.
# Keeping the terminal row as well would list readmitted subjects twice. The
# same subject (with the same marker) would then count as a case through its
# readmission row and as a control through its later terminal row.
.cr_data <- function(Data) {
  d3 <- Data[order(Data$id, Data$t.start, Data$t.stop), ]
  d3 <- d3[!duplicated(d3$id), ]
  status <- ifelse(d3$event == 1 & d3$status == 0, 1L,
             ifelse(d3$event == 0 & d3$status == 0, 0L, 2L))
  list(data = d3, status = status)
}

cr_jfm  <- .cr_data(Data_jfm)
cr_jscm <- .cr_data(Data_jscm)

# Baseline covariates (one row per subject, in order of first appearance)
Z_jfm  <- as.matrix(Data_jfm[!duplicated(Data_jfm$id),   paste0("x", 1:p)])
Z_jscm <- as.matrix(Data_jscm[!duplicated(Data_jscm$id), paste0("x", 1:p)])

# Row of Z belonging to each competing-risk row, found by matching the id
# value. Using the raw id as a row index is only correct when ids are exactly
# 1..n in order of first appearance; match() does not depend on that.
zrow_jfm  <- match(cr_jfm$data$id,  unique(Data_jfm$id))
zrow_jscm <- match(cr_jscm$data$id, unique(Data_jscm$id))

# Markers expanded to row level: alpha^T z for readmission, beta^T z for death
M_re_jfm  <- drop(Z_jfm  %*% cv_jfm$alpha)[zrow_jfm]
M_de_jfm  <- drop(Z_jfm  %*% cv_jfm$beta)[zrow_jfm]
M_re_jscm <- drop(Z_jscm %*% cv_jscm$alpha)[zrow_jscm]
M_de_jscm <- drop(Z_jscm %*% cv_jscm$beta)[zrow_jscm]

## ----auc, cache = TRUE--------------------------------------------------------
# timeROC is an optional dependency used only for this evaluation section.
# Stop with an actionable message rather than installing packages from inside
# the document; vignettes and scripts must not modify the user's library.
if (!requireNamespace("timeROC", quietly = TRUE)) {
  stop("Package 'timeROC' is needed for the AUC evaluation; ",
       "install it with install.packages(\"timeROC\").", call. = FALSE)
}
library(survival)
library(timeROC)

# Evaluation grid: `n` equally spaced times (20 by default) running from the
# 10th to the 85th percentile of the observed event times (status > 0).
.tgrid <- function(t_vec, status, n = 20) {
  event_times <- t_vec[status > 0]
  bounds <- quantile(event_times, probs = c(0.10, 0.85))
  seq(from = bounds[[1]], to = bounds[[2]], length.out = n)
}

# One evaluation grid per model, from its competing-risk event times.
t_jfm  <- .tgrid(t_vec = cr_jfm$data$t.stop,  status = cr_jfm$status)
t_jscm <- .tgrid(t_vec = cr_jscm$data$t.stop, status = cr_jscm$status)

# Readmission AUC: alpha^T z marker, event of interest coded 1 (cause = 1).
roc_re_jfm <- timeROC(
  T = cr_jfm$data$t.stop, delta = cr_jfm$status, times = t_jfm,
  marker = M_re_jfm, cause = 1,
  weighting = "marginal", ROC = FALSE, iid = FALSE
)
roc_re_jscm <- timeROC(
  T = cr_jscm$data$t.stop, delta = cr_jscm$status, times = t_jscm,
  marker = M_re_jscm, cause = 1,
  weighting = "marginal", ROC = FALSE, iid = FALSE
)

# Death AUC: beta^T z marker, event of interest coded 2 (cause = 2).
roc_de_jfm <- timeROC(
  T = cr_jfm$data$t.stop, delta = cr_jfm$status, times = t_jfm,
  marker = M_de_jfm, cause = 2,
  weighting = "marginal", ROC = FALSE, iid = FALSE
)
roc_de_jscm <- timeROC(
  T = cr_jscm$data$t.stop, delta = cr_jscm$status, times = t_jscm,
  marker = M_de_jscm, cause = 2,
  weighting = "marginal", ROC = FALSE, iid = FALSE
)

## ----auc-plot, fig.height = 5, fig.width = 8----------------------------------
# Extract the time-dependent AUC curve from a timeROC() fit.
#
# In the competing-risks setting timeROC() returns AUC_1 and AUC_2. These are
# the two definitions of *controls*, not causes:
#   1 = subjects event-free at t;
#   2 = event-free or already had the competing event.
# The cause of interest is fixed by the `cause` argument of timeROC() itself.
# Reading AUC_<cause> would therefore put readmission (definition 1) and death
# (definition 2) curves on different scales.
#
# roc:        a timeROC() result.
# cause:      kept for backward compatibility; ignored (see above).
# definition: which control definition to read (1 or 2); default 1.
# Returns a numeric vector aligned with roc$times (NA when unavailable).
.get_auc <- function(roc, cause = NULL, definition = 1L) {
  times <- roc[["times"]]
  auc <- roc[[paste0("AUC_", definition)]]
  # Exact lookup: `roc$AUC` would partially match AUC_1/AUC_2.
  if (is.null(auc)) auc <- roc[["AUC"]]
  if (is.null(auc) || !is.numeric(auc)) {
    return(rep(NA_real_, length(times)))
  }
  # Drop a leading extra entry when the AUC vector is one longer than times.
  if (length(auc) == length(times) + 1L) auc <- auc[-1L]
  as.numeric(auc)
}

old_par <- par(mfrow = c(1, 2), mar = c(4.5, 4, 3, 1))

# One panel per model: readmission AUC (solid blue) and death AUC (dashed
# red) over time, with the 0.5 no-discrimination reference line.
local({
  panels <- list(
    list(title = "JFM",  times = t_jfm,  re = roc_re_jfm,  de = roc_de_jfm),
    list(title = "JSCM", times = t_jscm, re = roc_re_jscm, de = roc_de_jscm)
  )
  for (pnl in panels) {
    plot(pnl$times, .get_auc(pnl$re, 1), type = "l", lwd = 2,
         col = "steelblue", xlab = "Time", ylab = "AUC(t)",
         main = pnl$title, ylim = c(0.4, 1))
    lines(pnl$times, .get_auc(pnl$de, 2), lwd = 2, col = "tomato", lty = 2)
    abline(h = 0.5, lty = 3, col = "grey60")
    legend("bottomleft", c("Readmission", "Death"),
           col = c("steelblue", "tomato"), lwd = 2, lty = c(1, 2),
           bty = "n", cex = 0.85)
  }
})

par(old_par)

