---
title: "ShrinkageTrees: Bayesian Tree Ensembles for Survival Analysis and Causal Inference"
author: "Tijn Jacobs"
date: "`r Sys.Date()`"
output:
  rmarkdown::html_vignette:
    toc: true
    toc_depth: 3
vignette: >
  %\VignetteEncoding{UTF-8}
  %\VignetteIndexEntry{ShrinkageTrees: Introduction and Usage}
  %\VignetteEngine{knitr::rmarkdown}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(
  collapse  = TRUE,
  comment   = "#>",
  message   = FALSE,
  warning   = FALSE,
  fig.width = 6,
  fig.height = 3.5,
  out.width = "100%"
)
library(ShrinkageTrees)
set.seed(42)
```

## Introduction

**ShrinkageTrees** is an R package that brings Bayesian Additive
Regression Trees (BART; Chipman, George & McCulloch, 2010) to
**survival analysis** and **causal inference**, with a particular focus
on high-dimensional data.

The package implements BART-based models for **right-censored** and
**interval-censored** survival outcomes using an accelerated failure time
(AFT) formulation. Censored event times are handled through Bayesian
data augmentation in the Gibbs sampler, enabling full posterior inference
without proportional-hazards assumptions. For causal inference, the
package provides Bayesian Causal Forests (BCF; Hahn, Murray & Carvalho,
2020), which decompose the outcome into a prognostic function
$\mu(\mathbf{x})$ and a treatment-effect function $\tau(\mathbf{x})$,
each estimated by a separate tree ensemble. This two-forest structure
supports estimation of heterogeneous treatment effects (CATEs) and the
average treatment effect (ATE).

A key feature is the availability of multiple **regularisation strategies**
that can be freely combined within a single model:

- **Classical BART priors** on the tree structure and leaf parameters.
- **Dirichlet splitting priors** (DART; Linero, 2018) for structural
  variable selection.
- **Horseshoe shrinkage on the leaf step heights** (Jacobs, van Wieringen &
  van der Pas, 2025) — a global–local prior that aggressively shrinks
  uninformative leaves toward zero while preserving strong signals.
  This is the main methodological novelty implemented in the package.
- **Half-Cauchy shrinkage** — a lighter-weight alternative that provides
  local shrinkage without a global scale parameter.

### Package map

| Function | Task | Prior |
|---|---|---|
| `HorseTrees()` | Prediction (continuous / binary / survival*) | Horseshoe |
| `ShrinkageTrees()` | Prediction — flexible prior choice* | Horseshoe, DART, BART, … |
| `SurvivalBART()` | Survival prediction* | Classical BART |
| `SurvivalDART()` | Sparse survival prediction* | DART (Dirichlet) |
| `SurvivalBCF()` | Causal survival inference* | BCF (classical) |
| `SurvivalShrinkageBCF()` | Sparse causal survival inference* | BCF + DART |
| `CausalHorseForest()` | Causal inference (all outcomes*) | Horseshoe |
| `CausalShrinkageForest()` | Causal inference — flexible prior* | Horseshoe, DART, BART, … |

\* All survival functions support both right-censored and interval-censored outcomes.

All model-fitting functions return an S3 object with consistent
`print()`, `summary()`, `predict()`, and `plot()` methods.


## Key Concepts

Before diving into examples, we clarify a few concepts that appear throughout
the package interface.

### Outcome types and the `timescale` parameter

Every model-fitting function accepts an `outcome_type` argument:

- `"continuous"` — standard regression (default for most functions).
- `"binary"` — probit BART for binary outcomes (0/1).
- `"right-censored"` — accelerated failure time model for survival data.
  The outcome `y` contains (possibly censored) follow-up times, and the
  `status` vector indicates events (1) vs. censored observations (0).
- `"interval-censored"` — AFT model for interval-censored survival data.
  Instead of `y` and `status`, provide `left_time` and `right_time` vectors
  specifying the lower and upper bounds of the observation window for each
  individual. Three cases are distinguished:
    - **Exact events**: `left_time == right_time` (event observed exactly).
    - **Interval-censored**: `left_time < right_time` with finite
      `right_time` (event occurred somewhere in the interval).
    - **Right-censored**: `right_time = Inf` (event not yet observed).

  This convention follows `survival::Surv(type = "interval2")`.
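As a quick cross-check of this convention, the same three cases can be expressed with the **survival** package (the times below are made up for illustration; note that `Surv(type = "interval2")` marks right censoring with `NA` where this package uses `Inf`):

```{r interval2-sketch, eval=requireNamespace("survival", quietly=TRUE)}
# Three observations: exact event, interval-censored, right-censored
left  <- c(5, 2, 7)
right <- c(5, 4, Inf)

# survival encodes right censoring as NA in the interval2 style
survival::Surv(left, ifelse(is.finite(right), right, NA), type = "interval2")
```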

For survival outcomes, the `timescale` argument controls how the package
treats the times:

- `timescale = "time"` (default): the supplied values are on the
  **original time scale** (positive numbers).
  The package internally applies a log-transform, i.e. models
  $\log(T) = f(\mathbf{x}) + \varepsilon$.
  Predictions from `summary()` and `predict()` are back-transformed
  to the time scale automatically.
- `timescale = "log"`: the supplied values are **already
  log-transformed**. No further transformation is applied.
  Predictions stay on the log scale.

In most applications you should use `timescale = "time"` and pass the
raw survival times directly.
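To make the two conventions concrete, the following calls (not run; `times`, `status`, and `X` are placeholders) are intended to fit the same AFT model, up to MCMC noise — with the difference that predictions from the second stay on the log scale:

```{r timescale-sketch, eval=FALSE}
# Option 1: raw times; the package log-transforms internally
fit_a <- HorseTrees(y = times, status = status, X_train = X,
                    outcome_type = "right-censored", timescale = "time")

# Option 2: pre-transformed times; no further transformation is applied
fit_b <- HorseTrees(y = log(times), status = status, X_train = X,
                    outcome_type = "right-censored", timescale = "log")
```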

### Shrinkage priors on the step heights

In a BART ensemble each tree contributes a step height (leaf parameter)
to the overall prediction. Classical BART assigns these step heights a
fixed-variance Gaussian prior, which regularises all leaves equally. In
high-dimensional settings, stronger and more adaptive regularisation is
desirable. **ShrinkageTrees** implements two shrinkage priors that are
placed directly on the step heights via a scale mixture of normals:

$$
h_\ell \mid \lambda_\ell, \tau, \omega \sim
\mathcal{N}(0,\; \omega\, \lambda_\ell^2\, \tau^2).
$$

Here $\tau$ is a **global** shrinkage parameter shared across all leaves,
$\lambda_\ell$ is a **local** scale specific to leaf $\ell$, and $\omega$
is a fixed scaling constant. The two currently implemented instantiations
are:

- **Horseshoe** (`prior_type = "horseshoe"`). Both $\lambda_\ell$ and
  $\tau$ receive independent half-Cauchy priors. The heavy tails of the
  half-Cauchy allow individual leaves to escape shrinkage when the data
  support a strong effect, while the global parameter $\tau$ pulls the
  bulk of the estimates toward zero. This is the default prior in
  `HorseTrees()` and `CausalHorseForest()`.
- **Half-Cauchy** (`prior_type = "half-cauchy"`). Only the local scales
  $\lambda_\ell$ receive a half-Cauchy prior; there is no global
  shrinkage parameter. This provides per-leaf adaptivity without the
  additional pooling across the ensemble.

A forest-wide variant of the horseshoe (`prior_type = "horseshoe_fw"`)
shares a single global $\tau$ across all trees in the forest rather than
one per tree. These priors can be selected in `ShrinkageTrees()` and
`CausalShrinkageForest()` via the `prior_type` argument, and can be
combined with DART's Dirichlet splitting prior for simultaneous
structural and parametric regularisation.

### Hyperparameter selection

The two most important hyperparameters are `local_hp` and `global_hp`.
These control the horseshoe prior on the step heights (leaf parameters):

$$
\mu_{jl} \mid \lambda_{jl}, \tau_j \sim \mathcal{N}(0, \lambda_{jl}^2 \tau_j^2),
\qquad \lambda_{jl} \sim \text{C}^+(0, \texttt{local\_hp}),
\qquad \tau_j \sim \text{C}^+(0, \texttt{global\_hp}),
$$

where $\text{C}^+$ denotes the half-Cauchy distribution.  Smaller values
produce stronger shrinkage toward zero; larger values allow more
variation.

**`HorseTrees()` and `CausalHorseForest()`** provide a convenience
parameter `k` that sets both scales automatically:
`local_hp = global_hp = k / sqrt(number_of_trees)`. The default
`k = 0.1` works well in many settings and is a good starting point.

**`ShrinkageTrees()` and `CausalShrinkageForest()`** expose `local_hp`
and `global_hp` directly (no `k`).  A common rule of thumb is:

- `local_hp = k / sqrt(number_of_trees)` with `k` in [0.05, 0.5].
- `global_hp = local_hp` (symmetric) or a larger value if you want
  less overall shrinkage.

The **survival functions** (`SurvivalBART`, `SurvivalDART`) use `k` to
calibrate the standard BART leaf prior:
`local_hp = diff(range(log(y))) / (2 * k * sqrt(number_of_trees))`,
i.e. the width of the observed log-time range, scaled by `k` and the
number of trees. The default `k = 2` follows Chipman et al. (2010).
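These rules of thumb are simple to compute by hand. A minimal base-R illustration (the numbers are purely for orientation):

```{r hp-rules}
m <- 200                      # number of trees
k <- 0.1                      # convenience parameter, as in HorseTrees()
local_hp  <- k / sqrt(m)      # per-leaf (local) half-Cauchy scale
global_hp <- local_hp         # symmetric choice for the global scale
round(c(local_hp = local_hp, global_hp = global_hp), 4)
```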

### The `store_posterior_sample` flag

When `store_posterior_sample = TRUE`, the fitted object stores the full
$N_\text{post} \times n$ matrix of posterior draws for predictions.
This is needed for:

- `predict()` on new data;
- `plot(fit, type = "ate")` and `plot(fit, type = "cate")`, which
  require the full posterior distribution;
- `plot(fit, type = "survival")` — full posterior credible bands over
  both $\mu_i$ and $\sigma$ (without samples, only sigma-uncertainty
  bands are available);
- any custom posterior analysis (e.g. computing posterior credible
  intervals for individual predictions).

When `FALSE`, only posterior *means* and $\sigma$ draws are stored,
saving memory.  `print()` and `summary()` work in both cases, but
`predict()` will not be available.
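When samples are stored, per-observation credible intervals are a one-liner. The field name below is an assumption for the single-forest case (the causal functions use `train_predictions_sample_treat`); check `names(fit)` for your model class:

```{r posterior-ci, eval=FALSE}
# Assumed field name; rows are posterior draws, columns are observations
draws <- fit$train_predictions_sample
ci_95 <- apply(draws, 2, quantile, probs = c(0.025, 0.975))
```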

### Treatment coding (`treatment_coding`)

All causal model functions — `CausalHorseForest()`,
`CausalShrinkageForest()`, `SurvivalBCF()`, and
`SurvivalShrinkageBCF()` — decompose the outcome as
$$
Y_i = \mu(\mathbf{x}_i) + b_i \cdot \tau(\mathbf{x}_i) + \varepsilon_i,
$$
where $b_i$ is a scalar that depends on the treatment assignment $Z_i$.
The `treatment_coding` argument controls how $b_i$ is defined.
Four options are available:

**`"centered"`** (default).
$b_i = Z_i - 1/2$, so that $b_i \in \{-1/2,\; 1/2\}$.
This is the original BCF parameterisation.

**`"binary"`**.
$b_i = Z_i$, so that $b_i \in \{0,\; 1\}$.
Standard binary coding; the treatment forest captures the full
effect of treatment on the treated.

**`"adaptive"`**.
$b_i = Z_i - \hat{e}(\mathbf{x}_i)$, where $\hat{e}(\mathbf{x}_i)$
is the estimated propensity score. This follows Hahn, Murray &
Carvalho (2020) and is the coding used in the `bcf` R package.
When using this option, a `propensity` vector must be supplied.

**`"invariant"`**.
Parameter-expanded (invariant) treatment coding. The coding
parameters $b_0$ and $b_1$ are assigned $N(0,\; 1/2)$ priors and
estimated within the Gibbs sampler via conjugate normal updates:
$$
Y_i = \mu(\mathbf{x}_i) + b_{Z_i} \cdot \tilde{\tau}(\mathbf{x}_i) + \varepsilon_i,
\qquad b_0,\; b_1 \sim N(0,\; 1/2).
$$
The treatment effect is $\tau(\mathbf{x}_i) = (b_1 - b_0) \cdot
\tilde{\tau}(\mathbf{x}_i)$, and the posterior draws of $b_0$ and $b_1$
are returned in the fitted object.  This parameterisation is invariant
to the coding of the treatment indicator (Hahn et al., 2020,
Section 5.2).

The examples below illustrate each option on a simple continuous-outcome
causal model.

```{r tc-data}
set.seed(50)
n_tc <- 60;  p_tc <- 5
X_tc <- matrix(rnorm(n_tc * p_tc), n_tc, p_tc)
W_tc <- rbinom(n_tc, 1, 0.5)
tau_tc <- 1.5 * (X_tc[, 1] > 0)
y_tc <- X_tc[, 1] + W_tc * tau_tc + rnorm(n_tc, sd = 0.5)
```

```{r tc-centered}
# Centered (default)
fit_tc_cen <- CausalHorseForest(
  y = y_tc,
  X_train_control = X_tc, X_train_treat = X_tc,
  treatment_indicator_train = W_tc,
  treatment_coding = "centered",
  number_of_trees = 5, N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE, verbose = FALSE
)
cat("Centered — ATE:",
    round(mean(fit_tc_cen$train_predictions_treat), 3), "\n")
```

```{r tc-binary}
# Binary
fit_tc_bin <- CausalHorseForest(
  y = y_tc,
  X_train_control = X_tc, X_train_treat = X_tc,
  treatment_indicator_train = W_tc,
  treatment_coding = "binary",
  number_of_trees = 5, N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE, verbose = FALSE
)
cat("Binary — ATE:",
    round(mean(fit_tc_bin$train_predictions_treat), 3), "\n")
```

```{r tc-adaptive}
# Adaptive (requires propensity scores)
ps_tc <- pnorm(0.3 * X_tc[, 1])   # simple propensity model for illustration

fit_tc_ada <- CausalHorseForest(
  y = y_tc,
  X_train_control = X_tc, X_train_treat = X_tc,
  treatment_indicator_train = W_tc,
  treatment_coding = "adaptive",
  propensity = ps_tc,
  number_of_trees = 5, N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE, verbose = FALSE
)
cat("Adaptive — ATE:",
    round(mean(fit_tc_ada$train_predictions_treat), 3), "\n")
```

```{r tc-invariant}
# Invariant (parameter-expanded)
fit_tc_inv <- CausalHorseForest(
  y = y_tc,
  X_train_control = X_tc, X_train_treat = X_tc,
  treatment_indicator_train = W_tc,
  treatment_coding = "invariant",
  number_of_trees = 5, N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE, verbose = FALSE
)
cat("Invariant — ATE:",
    round(mean(fit_tc_inv$train_predictions_treat), 3), "\n")

# Posterior draws of b0 and b1 are stored in the fitted object
cat("b0 posterior mean:", round(mean(fit_tc_inv$b0), 3), "\n")
cat("b1 posterior mean:", round(mean(fit_tc_inv$b1), 3), "\n")
```

The survival functions inherit `treatment_coding` support. For example,
`SurvivalBCF(..., treatment_coding = "invariant")` works out of the box.


## Included Datasets

The package ships with two TCGA datasets for high-dimensional survival
analysis and causal inference:

- **`pdac`** — TCGA pancreatic ductal adenocarcinoma (PAAD) cohort
  (n = 178). A data frame with overall survival times, a binary treatment
  indicator (radiation therapy vs. control), clinical covariates, and
  expression values of ~3,000 genes selected by median absolute deviation.
- **`ovarian`** — TCGA ovarian cancer (OV) cohort (n = 357). A list with
  `X` (357 x 2,000 gene expression matrix, log2-normalised TPM) and
  `clinical` (data frame with OS time/event, age, FIGO stage, tumor grade,
  and treatment: carboplatin vs cisplatin). See `?ovarian` for details.
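A quick look at the `ovarian` object confirms the structure described above:

```{r ovarian-peek}
data("ovarian", package = "ShrinkageTrees")
dim(ovarian$X)
names(ovarian$clinical)
```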

### The PDAC Dataset

The `pdac` data frame bundles the survival outcome, the treatment
indicator, 13 clinical columns, and the gene-expression features in a
single object. We load it and inspect its structure below.

```{r load-data}
library(ShrinkageTrees)
data("pdac")

# Dimensions and column overview
cat("Patients:", nrow(pdac), "\n")
cat("Columns :", ncol(pdac), "\n")
cat("Clinical columns:", paste(names(pdac)[1:13], collapse = ", "), "\n")
cat("Survival: time (months), censoring rate =",
    round(1 - mean(pdac$status), 2), "\n")
cat("Treatment: radiation =", sum(pdac$treatment),
    "/ control =", sum(1 - pdac$treatment), "\n")
```

We separate the outcome, treatment, and covariate matrix for the analyses
below.

```{r prepare-data}
time      <- pdac$time
status    <- pdac$status
treatment <- pdac$treatment
X         <- as.matrix(pdac[, !(names(pdac) %in% c("time", "status", "treatment"))])
```


## Prediction Models

This section demonstrates the single-forest models in **ShrinkageTrees**.
These models estimate a single function $f(\mathbf{x})$ of the covariates,
applicable to continuous, binary, and survival outcomes. We begin with a
binary-outcome example that will also serve as the propensity score model
for the causal analyses later.

### HorseTrees — binary outcome (propensity scores)

Before fitting causal models we estimate propensity scores
$\hat{e}(\mathbf{x}) = P(W=1 \mid \mathbf{x})$ using `HorseTrees()` with a
binary outcome. The probit link is used internally: predictions are on the
latent Gaussian scale and can be converted to probabilities with `pnorm()`.

The code block below uses reduced MCMC settings for illustration.
A real analysis would use `N_post = 5000, N_burn = 5000`.

```{r horsetrees-binary, eval=FALSE}
ps_fit <- HorseTrees(
  y            = treatment,
  X_train      = X,
  outcome_type = "binary",
  k            = 0.1,
  N_post       = 5000,
  N_burn       = 5000,
  verbose      = FALSE
)

propensity <- pnorm(ps_fit$train_predictions)
```

For the remainder of this vignette we use a short synthetic run to keep
build time low.

```{r horsetrees-binary-small}
set.seed(1)
n <- 80;  p <- 10
X_syn  <- matrix(rnorm(n * p), n, p)
W_syn  <- rbinom(n, 1, pnorm(0.8 * X_syn[, 1]))

ps_fit <- HorseTrees(
  y            = W_syn,
  X_train      = X_syn,
  outcome_type = "binary",
  number_of_trees = 5,
  k            = 0.5,
  N_post       = 50,
  N_burn       = 25,
  verbose      = FALSE
)

propensity_syn <- pnorm(ps_fit$train_predictions)
cat("Propensity scores — range: [",
    round(range(propensity_syn), 3), "]\n")
```

### HorseTrees — survival outcome

`HorseTrees()` handles right-censored data via an AFT model.
Pass `outcome_type = "right-censored"` and provide the `status` vector
(1 = event, 0 = censored). When `timescale = "time"` (the default),
the package log-transforms the survival times internally; the stored
`train_predictions` are posterior means on the log scale, while
`summary()` and `predict()` back-transform to the time scale
(see *Key Concepts* above).

```{r horsetrees-survival}
set.seed(2)
log_T <- X_syn[, 1] + rnorm(n)
C     <- rexp(n, 0.5)
y_syn   <- pmin(exp(log_T), C)
d_syn   <- as.integer(exp(log_T) <= C)

ht_surv <- HorseTrees(
  y               = y_syn,
  status          = d_syn,
  X_train         = X_syn,
  outcome_type    = "right-censored",
  timescale       = "time",
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  store_posterior_sample = TRUE,
  verbose         = FALSE
)

cat("Posterior mean log-time (first 5 obs):",
    round(ht_surv$train_predictions[1:5], 3), "\n")
cat("Posterior sigma — mean:",
    round(mean(ht_surv$sigma), 3), "\n")
```

### HorseTrees — interval-censored outcome

When event times are not observed exactly but known to lie within an
interval, the package supports **interval censoring**. Instead of
providing `y` and `status`, pass `left_time` and `right_time` with
`outcome_type = "interval-censored"`.

The three censoring types are encoded as follows:

- **Exact event**: `left_time[i] == right_time[i]`
- **Interval-censored**: `left_time[i] < right_time[i]` (both finite)
- **Right-censored**: `right_time[i] = Inf`

This convention matches `survival::Surv(type = "interval2")`.

```{r horsetrees-ic}
set.seed(20)

# Generate true event times
true_T <- rexp(n, rate = exp(-0.5 * X_syn[, 1]))

# Create interval-censored observations
left_syn  <- true_T * runif(n, 0.5, 1.0)
right_syn <- true_T * runif(n, 1.0, 1.5)

# Mark some as exact observations and some as right-censored
exact_idx <- sample(n, 25)
left_syn[exact_idx]  <- true_T[exact_idx]
right_syn[exact_idx] <- true_T[exact_idx]

rc_idx <- sample(setdiff(seq_len(n), exact_idx), 15)
right_syn[rc_idx] <- Inf

cat("Exact events:", sum(left_syn == right_syn), "\n")
cat("Interval-censored:", sum(left_syn < right_syn & is.finite(right_syn)), "\n")
cat("Right-censored:", sum(!is.finite(right_syn)), "\n")

ht_ic <- HorseTrees(
  left_time       = left_syn,
  right_time      = right_syn,
  X_train         = X_syn,
  outcome_type    = "interval-censored",
  timescale       = "time",
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  store_posterior_sample = TRUE,
  verbose         = FALSE
)

cat("Posterior mean log-time (first 5 obs):",
    round(ht_ic$train_predictions[1:5], 3), "\n")
cat("Posterior sigma — mean:",
    round(mean(ht_ic$sigma), 3), "\n")
```

All survival functions (`SurvivalBART`, `SurvivalDART`, `SurvivalBCF`,
`SurvivalShrinkageBCF`) and the general-purpose functions (`ShrinkageTrees`,
`CausalShrinkageForest`, `CausalHorseForest`) accept `left_time` and
`right_time` in the same way. For example, using `SurvivalBART`:

```{r sbart-ic}
set.seed(21)
fit_sbart_ic <- SurvivalBART(
  left_time       = left_syn,
  right_time      = right_syn,
  X_train         = X_syn,
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)

cat("SurvivalBART (IC) class:", class(fit_sbart_ic), "\n")
```

### ShrinkageTrees — flexible prior choice

While `HorseTrees()` fixes the prior to the horseshoe, the more general
`ShrinkageTrees()` function exposes the `prior_type` argument, allowing
the user to select among all implemented regularisation strategies.
Available options are `"horseshoe"`, `"horseshoe_fw"` (forest-wide),
`"half-cauchy"`, `"standard"`
(classical BART), and `"dirichlet"` (DART). Below we compare the
per-tree horseshoe and the forest-wide horseshoe on a continuous outcome.

```{r shrinkage-continuous}
set.seed(3)
y_cont <- X_syn[, 1] + 0.5 * X_syn[, 2] + rnorm(n)

# Horseshoe prior (default for HorseTrees)
fit_hs <- ShrinkageTrees(
  y               = y_cont,
  X_train         = X_syn,
  outcome_type    = "continuous",
  prior_type      = "horseshoe",
  local_hp        = 0.1 / sqrt(5),
  global_hp       = 0.1 / sqrt(5),
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)

# Forest-wide horseshoe (horseshoe_fw)
fit_fw <- ShrinkageTrees(
  y               = y_cont,
  X_train         = X_syn,
  outcome_type    = "continuous",
  prior_type      = "horseshoe_fw",
  local_hp        = 0.1 / sqrt(5),
  global_hp       = 0.1 / sqrt(5),
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)

cat("Horseshoe    — train RMSE:",
    round(sqrt(mean((fit_hs$train_predictions - y_cont)^2)), 3), "\n")
cat("Horseshoe FW — train RMSE:",
    round(sqrt(mean((fit_fw$train_predictions - y_cont)^2)), 3), "\n")
```

### SurvivalBART and SurvivalDART

`SurvivalBART()` and `SurvivalDART()` fit classical BART and DART models
for right-censored survival data under the AFT formulation. They
calibrate prior hyperparameters automatically from the data range,
providing a simple interface when horseshoe shrinkage is not needed.

```{r survival-bart-dart}
set.seed(4)

# SurvivalBART: classical BART prior, AFT likelihood
fit_sbart <- SurvivalBART(
  time            = y_syn,
  status          = d_syn,
  X_train         = X_syn,
  number_of_trees = 5,
  k               = 2.0,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)

# SurvivalDART: Dirichlet (DART) splitting prior
fit_sdart <- SurvivalDART(
  time            = y_syn,
  status          = d_syn,
  X_train         = X_syn,
  number_of_trees = 5,
  k               = 2.0,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)

cat("SurvivalBART  class:", class(fit_sbart), "\n")
cat("SurvivalDART  class:", class(fit_sdart), "\n")
```

## High-Dimensional Survival Analysis

A key motivation for horseshoe shrinkage and the Dirichlet (DART) sparsity
prior is their behaviour in the **$p \gg n$ regime**: many covariates are
available but only a small subset drives the outcome. Classical BART may
struggle here because the standard Gaussian leaf prior is non-sparse and
does not concentrate on a small number of predictors.

We illustrate both priors on a sparse AFT simulation: $n = 60$ observations,
$p = 200$ predictors, and only three active predictors.

```{r hd-data}
set.seed(20)
n_hd <- 60;  p_hd <- 200
X_hd <- matrix(rnorm(n_hd * p_hd), n_hd, p_hd)

# True log-survival depends only on predictors 1, 2, and 3
log_T_hd <- 1.5 * X_hd[, 1] - 1.0 * X_hd[, 2] + 0.5 * X_hd[, 3] + rnorm(n_hd)
C_hd     <- rexp(n_hd, rate = 0.5)
y_hd     <- pmin(exp(log_T_hd), C_hd)
d_hd     <- as.integer(exp(log_T_hd) <= C_hd)

cat("n =", n_hd, "| p =", p_hd,
    "| active predictors = 3",
    "| censoring rate =", round(1 - mean(d_hd), 2), "\n")
```

**ShrinkageTrees (horseshoe)** places global–local shrinkage on the step
heights of every leaf, automatically regularising all 200 predictors
toward zero while preserving the signal in the three active ones.

```{r hd-horseshoe}
set.seed(21)
fit_hd_hs <- ShrinkageTrees(
  y               = y_hd,
  status          = d_hd,
  X_train         = X_hd,
  outcome_type    = "right-censored",
  prior_type      = "horseshoe",
  local_hp        = 0.1 / sqrt(10),
  global_hp       = 0.1 / sqrt(10),
  number_of_trees = 10,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)
```

**SurvivalDART** uses a Dirichlet prior on split probabilities to induce
structural sparsity: after burn-in, most splitting probability is
concentrated on truly predictive variables. Setting `rho_dirichlet = 3`
encodes the prior belief that approximately three predictors are active.

```{r hd-dart}
set.seed(22)
fit_hd_dart <- SurvivalDART(
  time            = y_hd,
  status          = d_hd,
  X_train         = X_hd,
  number_of_trees = 10,
  rho_dirichlet   = 3,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)
```

Both models run without error in the $p > n$ regime. We compare their
posterior mean predictions in log-time against the latent true values
used to generate the data.

```{r hd-compare}
rmse_hs   <- sqrt(mean((fit_hd_hs$train_predictions  - log_T_hd)^2))
rmse_dart <- sqrt(mean((fit_hd_dart$train_predictions - log_T_hd)^2))

cat(sprintf("%-18s  train RMSE (log-time): %.3f\n", "Horseshoe", rmse_hs))
cat(sprintf("%-18s  train RMSE (log-time): %.3f\n", "DART",      rmse_dart))
```

The DART model also produces **variable importance** plots that display
the posterior distribution of each predictor's splitting probability.
With only 50 posterior draws the top-10 plot below should already concentrate
most probability mass near the three truly active predictors.

```{r hd-vi, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_hd_dart, type = "vi", n_vi = 10)
```

## Causal Forest Models

For causal inference, **ShrinkageTrees** provides Bayesian Causal Forest
(BCF) models that decompose the outcome into a prognostic component and
a treatment effect component:
$$
Y_i = \mu(\mathbf{x}_i) + W_i \cdot \tau(\mathbf{x}_i) + \varepsilon_i,
$$
where $\mu(\cdot)$ is the prognostic (control) function modelled by one
tree ensemble, and $\tau(\cdot)$ is the heterogeneous treatment effect
modelled by a second ensemble. This two-forest structure allows each
component to have its own regularisation, number of trees, and prior —
for instance, a standard BART prior for the prognostic forest and
horseshoe shrinkage for the treatment effect forest.

The package provides four causal model functions with increasing
generality: `SurvivalBCF()` (classical BCF for survival),
`SurvivalShrinkageBCF()` (BCF + DART for survival),
`CausalHorseForest()` (horseshoe BCF for all outcome types), and
`CausalShrinkageForest()` (fully configurable BCF). We illustrate each
below on synthetic data with a known treatment effect.

```{r causal-data}
set.seed(5)
tau_true <- 1.5 * (X_syn[, 1] > 0)    # heterogeneous treatment effect
y_causal <- X_syn[, 1] + W_syn * tau_true + rnorm(n, sd = 0.5)
```

### SurvivalBCF — classical BCF for survival

`SurvivalBCF()` fits a BCF model for right-censored survival outcomes using
classical BART priors.

```{r survbcf, eval=FALSE}
# Full analysis (eval=FALSE — use larger MCMC settings in practice)
fit_sbcf <- SurvivalBCF(
  time       = time,
  status     = status,
  X_train    = X,
  treatment  = treatment,
  propensity = propensity,   # from HorseTrees above
  N_post     = 5000,
  N_burn     = 5000,
  verbose    = FALSE
)
```

```{r survbcf-small}
set.seed(6)
fit_sbcf <- SurvivalBCF(
  time            = y_syn,
  status          = d_syn,
  X_train         = X_syn,
  treatment       = W_syn,
  number_of_trees_control = 5,
  number_of_trees_treat   = 5,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)
cat("SurvivalBCF class:", class(fit_sbcf), "\n")
cat("ATE (posterior mean):",
    round(mean(fit_sbcf$train_predictions_treat), 3), "\n")
```

### SurvivalShrinkageBCF — sparse causal survival forest

`SurvivalShrinkageBCF()` extends BCF with a Dirichlet splitting prior on
both forests, inducing sparsity in high-dimensional settings.

```{r survsbcf-small}
set.seed(7)
fit_ssbcf <- SurvivalShrinkageBCF(
  time            = y_syn,
  status          = d_syn,
  X_train         = X_syn,
  treatment       = W_syn,
  number_of_trees_control = 5,
  number_of_trees_treat   = 5,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)
cat("SurvivalShrinkageBCF class:", class(fit_ssbcf), "\n")
```

### CausalHorseForest — horseshoe causal forest

`CausalHorseForest()` is the primary novel contribution of this package.
It applies horseshoe shrinkage to the leaf parameters of both the prognostic
and treatment-effect forests. This enables effective regularisation when
many covariates are available but few are truly predictive of heterogeneous
treatment effects.

```{r causal-horse}
set.seed(8)
fit_chf <- CausalHorseForest(
  y                         = y_causal,
  X_train_control           = X_syn,
  X_train_treat             = X_syn,
  treatment_indicator_train = W_syn,
  outcome_type              = "continuous",
  number_of_trees           = 5,
  N_post                    = 50,
  N_burn                    = 25,
  store_posterior_sample    = TRUE,
  verbose                   = FALSE
)

cat("CausalHorseForest class:", class(fit_chf), "\n")

# Posterior mean CATE
cate_mean <- fit_chf$train_predictions_treat
cat("CATE — posterior mean (first 5):",
    round(cate_mean[1:5], 3), "\n")

# Posterior ATE
ate_samples <- rowMeans(fit_chf$train_predictions_sample_treat)
cat("ATE posterior mean:",
    round(mean(ate_samples), 3),
    "  95% CI: [",
    round(quantile(ate_samples, 0.025), 3), ",",
    round(quantile(ate_samples, 0.975), 3), "]\n")
```

The fitted object stores the posterior mean CATE for each training
observation in `train_predictions_treat`. When
`store_posterior_sample = TRUE`, the full posterior sample matrix is
available in `train_predictions_sample_treat`, from which the posterior
ATE distribution and credible intervals can be computed as shown above.
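The same sample matrix also yields per-observation CATE credible intervals. As in the ATE computation above, rows of the matrix are posterior draws and columns are observations:

```{r cate-ci}
# 95% credible interval for each observation's CATE
cate_ci <- apply(fit_chf$train_predictions_sample_treat, 2,
                 quantile, probs = c(0.025, 0.975))
cat("CATE 95% CI (first obs): [",
    round(cate_ci[1, 1], 3), ",", round(cate_ci[2, 1], 3), "]\n")
```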

You can also supply separate test matrices to obtain out-of-sample
CATE predictions.

```{r causal-horse-test}
set.seed(9)
X_test <- matrix(rnorm(20 * p), 20, p)
W_test <- rbinom(20, 1, 0.5)

fit_chf_test <- CausalHorseForest(
  y                         = y_causal,
  X_train_control           = X_syn,
  X_train_treat             = X_syn,
  treatment_indicator_train = W_syn,
  X_test_control            = X_test,
  X_test_treat              = X_test,
  treatment_indicator_test  = W_test,
  outcome_type              = "continuous",
  number_of_trees           = 5,
  N_post                    = 50,
  N_burn                    = 25,
  store_posterior_sample    = TRUE,
  verbose                   = FALSE
)

cat("Test CATE (first 5):",
    round(fit_chf_test$test_predictions_treat[1:5], 3), "\n")
```

### CausalShrinkageForest — flexible causal priors

`CausalShrinkageForest()` is the most general causal model interface.
It allows independent prior choices for the prognostic and treatment
effect forests via `prior_type_control` and `prior_type_treat`. For
example, one could use a standard BART prior for the prognostic forest
(where variable selection is less critical) and horseshoe shrinkage for
the treatment forest (where most covariates are expected to be
irrelevant for the treatment effect).

```{r causal-shrinkage}
set.seed(10)
lh <- 0.1 / sqrt(5)

fit_csf <- CausalShrinkageForest(
  y                         = y_causal,
  X_train_control           = X_syn,
  X_train_treat             = X_syn,
  treatment_indicator_train = W_syn,
  outcome_type              = "continuous",
  prior_type_control        = "horseshoe",
  prior_type_treat          = "horseshoe",
  local_hp_control          = lh,
  global_hp_control         = lh,
  local_hp_treat            = lh,
  global_hp_treat           = lh,
  number_of_trees_control   = 5,
  number_of_trees_treat     = 5,
  N_post                    = 50,
  N_burn                    = 25,
  store_posterior_sample    = TRUE,
  verbose                   = FALSE
)

cat("CausalShrinkageForest class:", class(fit_csf), "\n")
cat("Acceptance ratio (control):",
    round(fit_csf$acceptance_ratio_control, 3), "\n")
cat("Acceptance ratio (treat)  :",
    round(fit_csf$acceptance_ratio_treat, 3), "\n")
```

The `horseshoe_fw` prior adds a forest-wide shrinkage parameter that is
tracked in the fitted object.

```{r causal-shrinkage-fw}
set.seed(11)
fit_fw2 <- CausalShrinkageForest(
  y                         = y_causal,
  X_train_control           = X_syn,
  X_train_treat             = X_syn,
  treatment_indicator_train = W_syn,
  outcome_type              = "continuous",
  prior_type_control        = "horseshoe_fw",
  prior_type_treat          = "horseshoe_fw",
  local_hp_control          = lh,
  global_hp_control         = lh,
  local_hp_treat            = lh,
  global_hp_treat           = lh,
  number_of_trees_control   = 5,
  number_of_trees_treat     = 5,
  N_post                    = 50,
  N_burn                    = 25,
  verbose                   = FALSE
)

cat("Forest-wide shrinkage (control, first 5 draws):\n")
print(round(fit_fw2$forestwide_shrinkage_control[1:5], 4))
```

## S3 Methods

All fitted objects — whether from prediction models or causal models —
support a consistent set of S3 methods: `print()`, `summary()`,
`predict()`, and `plot()`. This section illustrates each method using
the models fitted above.

### print()

Calling `print()` (or just typing the object name) displays a concise model
summary.

```{r print}
print(fit_chf)
```

For causal models the output additionally shows the number of trees in each
forest and prior details for both components.

```{r print-csf}
print(fit_csf)
```

### summary()

`summary()` returns a structured list and displays a richer description
including posterior statistics for $\sigma$, acceptance ratios, and
treatment effect estimates.

```{r summary-shrinkage}
smry <- summary(fit_hs)
print(smry)
```

For causal models the summary includes the posterior ATE with a 95% credible
interval (when `store_posterior_sample = TRUE`).

```{r summary-causal}
smry_c <- summary(fit_chf)
print(smry_c)

# Access the ATE directly
cat("ATE mean  :", round(smry_c$treatment_effect$ate, 3), "\n")
cat("ATE 95% CI: [",
    round(smry_c$treatment_effect$ate_lower, 3), ",",
    round(smry_c$treatment_effect$ate_upper, 3), "]\n")
```

#### Population vs. mixed ATE

By default the ATE credible interval is obtained by a **Bayesian
bootstrap**: at each MCMC iteration $s$ the observation-level CATEs
$\tau^{(s)}(x_i)$ are reweighted with Dirichlet(1, ..., 1) weights,

$$
\widehat{\mathrm{PATE}}^{(s)} \;=\; \sum_{i=1}^n w_i^{(s)}\, \tau^{(s)}(x_i),
\qquad (w_1^{(s)}, \dots, w_n^{(s)}) \sim \mathrm{Dir}(1, \dots, 1).
$$

The collection $\{\widehat{\mathrm{PATE}}^{(s)}\}$ approximates the
posterior of the **population ATE** and therefore propagates uncertainty
in both $\tau(\cdot)$ and the covariate distribution $F_X$. Setting
`bayesian_bootstrap = FALSE` reverts to equal $1/n$ weights, giving the
**mixed ATE** (MATE) that conditions on the observed covariates and has a
narrower credible interval.
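
The reweighting is straightforward to sketch in base R. Below, `tau_draws` is a stand-in for the $S \times n$ matrix of posterior CATE samples (simulated here, not taken from a fitted model); Dirichlet$(1, \dots, 1)$ weights are drawn as normalised unit-rate Gamma variables:

```{r pate-sketch}
# Sketch of the Bayesian-bootstrap reweighting, using simulated CATE
# draws rather than the output of a fitted model
n_iter <- 100; n_obs <- 50
tau_draws <- matrix(rnorm(n_iter * n_obs, mean = 1), n_iter, n_obs)

pate_draws <- apply(tau_draws, 1, function(tau_s) {
  g <- rgamma(n_obs, shape = 1)   # Dirichlet(1, ..., 1) weights arise as
  w <- g / sum(g)                 # normalised unit-rate Gamma draws
  sum(w * tau_s)
})

mate_draws <- rowMeans(tau_draws) # equal 1/n weights: the mixed ATE

# The PATE posterior is wider: it also reflects uncertainty about F_X
c(sd_pate = sd(pate_draws), sd_mate = sd(mate_draws))
```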

```{r summary-pate-mate}
smry_pate <- summary(fit_chf, bayesian_bootstrap = TRUE)   # default
smry_mate <- summary(fit_chf, bayesian_bootstrap = FALSE)
```

The standalone helper `bayesian_bootstrap_ate()` returns both posteriors
and their draws in a single list. It also accepts a
`CausalShrinkageForestPrediction` object returned by `predict()`, so the
PATE can integrate over a prespecified target population.

```{r bb-ate}
bb <- bayesian_bootstrap_ate(fit_chf)
cat("PATE:", round(bb$pate_mean, 3),
    " 95% CI: [", round(bb$pate_ci$lower, 3), ",",
                  round(bb$pate_ci$upper, 3), "]\n")
cat("MATE:", round(bb$mate_mean, 3),
    " 95% CI: [", round(bb$mate_ci$lower, 3), ",",
                  round(bb$mate_ci$upper, 3), "]\n")
```

### predict()

`predict()` computes the posterior predictive distribution on new data.
It returns a `ShrinkageTreesPrediction` object with posterior mean and
credible-interval vectors.

```{r predict}
X_new  <- matrix(rnorm(10 * p), 10, p)

pred <- predict(fit_hs, newdata = X_new)
print(pred)
```

```{r predict-ci}
# Point estimates and 95% credible intervals
head(data.frame(
  mean  = round(pred$mean,  3),
  lower = round(pred$lower, 3),
  upper = round(pred$upper, 3)
))
```

#### Causal predictions

For causal models (`CausalShrinkageForest` and `CausalHorseForest`),
`predict()` returns three sets of posterior summaries:

- **prognostic**: the control-forest prediction $\mu(\mathbf{x})$ — the
  expected outcome under control.
- **cate**: the Conditional Average Treatment Effect $\tau(\mathbf{x})$ —
  the additional effect of treatment for each individual.
- **total**: the combined prediction $\mu(\mathbf{x}) + \tau(\mathbf{x})$
  — the expected outcome under treatment.

For survival models with `timescale = "time"`, the prognostic and total
components are back-transformed to the original time scale (via
$\exp(\cdot)$), and the CATE becomes a **multiplicative time ratio**:
$\exp(\tau) > 1$ means treatment prolongs survival.
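
For example, a posterior draw of $\tau = 0.3$ on the log-time scale corresponds to a time ratio of $\exp(0.3) \approx 1.35$, i.e. roughly 35% longer expected survival. A hypothetical vector of CATE draws (simulated here, not from a fitted model) can be summarised on the ratio scale directly:

```{r time-ratio-sketch}
# Simulated log-scale CATE draws for one individual (illustration only)
cate_draws_demo <- rnorm(200, mean = 0.3, sd = 0.2)

time_ratio <- exp(cate_draws_demo)   # > 1 means treatment prolongs survival
c(ratio_mean   = mean(time_ratio),
  prob_benefit = mean(time_ratio > 1))
```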

The `predict()` method requires two covariate matrices — one for each
forest — matching the columns used at fit time.

```{r predict-causal}
X_new_ctrl  <- matrix(rnorm(10 * p), 10, p)
X_new_treat <- matrix(rnorm(10 * p), 10, p)

pred_c <- predict(fit_chf, newdata_control = X_new_ctrl,
                  newdata_treat = X_new_treat)
print(pred_c)
```

```{r predict-causal-detail}
# Extract individual components
head(data.frame(
  prognostic = round(pred_c$prognostic$mean, 3),
  cate       = round(pred_c$cate$mean, 3),
  total      = round(pred_c$total$mean, 3)
))
```

### plot()

The `plot()` method produces diagnostic and inferential graphics using the
**ggplot2** package (a suggested dependency).

#### Sigma traceplot

```{r plot-trace, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_hs, type = "trace")
```

#### Posterior ATE distribution (causal models)

The ATE density uses the Bayesian-bootstrap PATE posterior by default;
pass `bayesian_bootstrap = FALSE` to plot the (narrower) mixed ATE
density instead. See the summary section above for the definitions.

```{r plot-ate, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_chf, type = "ate")                         # PATE (default)
plot(fit_chf, type = "ate", bayesian_bootstrap = FALSE)  # MATE
```

#### CATE caterpillar plot

```{r plot-cate, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_chf, type = "cate")
```

#### Variable importance (Dirichlet prior)

Variable importance plots are available when `prior_type = "dirichlet"`.
For causal models with `prior_type_control = "dirichlet"` or
`prior_type_treat = "dirichlet"`, use `forest = "control"`, `"treat"`,
or `"both"`.

```{r vi-fit}
set.seed(12)
fit_dart <- ShrinkageTrees(
  y               = y_cont,
  X_train         = X_syn,
  outcome_type    = "continuous",
  prior_type      = "dirichlet",
  local_hp        = 0.1 / sqrt(5),
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)
```

```{r plot-vi, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_dart, type = "vi", n_vi = 10)
```

#### Survival curves

For survival models (`outcome_type = "right-censored"` or
`"interval-censored"`), the `plot()` method can draw posterior survival
curves derived from the fitted AFT log-normal model:
$$
S(t \mid \mathbf{x}_i) = 1 - \Phi\!\left(\frac{\log t - \mu_i}{\sigma}\right),
$$
where $\mu_i = f(\mathbf{x}_i)$ is the BART ensemble prediction and $\sigma$
is the residual standard deviation on the log-time scale.

The `type = "survival"` option supports two modes controlled by the `obs`
argument:

- **Population-averaged curve** (`obs = NULL`, the default): computes
  $\bar{S}(t) = n^{-1}\sum_i S(t \mid \mathbf{x}_i)$ at each MCMC
  iteration, giving credible bands that reflect full posterior uncertainty.
- **Individual curves** (`obs = c(1, 5, ...)`): one curve per selected
  training observation with its own credible band.
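
The population-averaged computation can be sketched in base R. Here `mu_draws` and `sigma_draws` stand in for posterior samples of $f(\mathbf{x}_i)$ and $\sigma$ (simulated, not taken from a fitted object):

```{r survival-curve-sketch}
# Simulated posterior draws standing in for a fitted model's output
n_iter <- 100; n_obs <- 30
mu_draws    <- matrix(rnorm(n_iter * n_obs), n_iter, n_obs)
sigma_draws <- sqrt(1 / rgamma(n_iter, shape = 3, rate = 3))
t_grid_demo <- seq(0.1, 5, length.out = 50)

# At each draw s: S_bar(t) = mean_i { 1 - Phi((log t - mu_i^(s)) / sigma^(s)) }
S_bar <- sapply(t_grid_demo, function(t)
  rowMeans(1 - pnorm((log(t) - mu_draws) / sigma_draws)))

curve_mean <- colMeans(S_bar)                               # posterior mean
band <- apply(S_bar, 2, quantile, probs = c(0.025, 0.975))  # 95% band
```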

Additional options:

| Argument | Description |
|---|---|
| `level` | Width of the pointwise credible band (default `0.95`). |
| `t_grid` | Custom time grid (original scale). Auto-generated if `NULL`. |
| `km` | If `TRUE`, overlay the Kaplan–Meier estimate (population-average only). |

We use the survival fit from the earlier section:

```{r surv-curve-pop, eval=requireNamespace("ggplot2", quietly=TRUE)}
# Population-averaged survival curve with 95% credible band
plot(ht_surv, type = "survival")
```

```{r surv-curve-km, eval=requireNamespace("ggplot2", quietly=TRUE) && requireNamespace("survival", quietly=TRUE)}
# Same curve with the Kaplan-Meier estimate overlaid for comparison
plot(ht_surv, type = "survival", km = TRUE)
```

```{r surv-curve-ind, eval=requireNamespace("ggplot2", quietly=TRUE)}
# Individual survival curves for observations 1, 20, 40, 60, and 80
plot(ht_surv, type = "survival", obs = c(1, 20, 40, 60, 80))
```

```{r surv-curve-single, eval=requireNamespace("ggplot2", quietly=TRUE)}
# Single individual with a narrower 90% credible band
plot(ht_surv, type = "survival", obs = 1, level = 0.90)
```

When `store_posterior_sample = FALSE`, the credible bands only reflect
uncertainty in $\sigma$ (using plug-in posterior mean $\hat{\mu}_i$).
The survival functions (`SurvivalBART`, `SurvivalDART`, etc.) store
posterior samples by default, so full posterior bands are available
out of the box.

#### Posterior predictive survival curves

The survival curves above are based on the **training** data — they show
$S(t \mid \mathbf{x}_i)$ for the observations used to fit the model.
For **new** (out-of-sample) data, call `predict()` first and then
`plot()` on the prediction object. This produces *posterior
predictive* survival curves that propagate full parameter uncertainty
through to the new covariate values:

```{r surv-pred-setup}
# New observations for prediction
set.seed(99)
X_new <- matrix(rnorm(20 * p), ncol = p)
pred_surv <- predict(ht_surv, newdata = X_new)
```

```{r surv-pred-pop, eval=requireNamespace("ggplot2", quietly=TRUE)}
# Population-averaged posterior predictive survival curve
plot(pred_surv, type = "survival")
```

```{r surv-pred-ind, eval=requireNamespace("ggplot2", quietly=TRUE)}
# Individual posterior predictive curves for selected new observations
plot(pred_surv, type = "survival", obs = c(1, 5, 10))
```

The same `level` and `t_grid` arguments are available as for the
training-data survival curves. The Kaplan–Meier overlay (`km = TRUE`)
is not available for prediction objects, since observed event times are
only known for the training set.

## Multi-Chain MCMC

Running multiple independent chains enables cross-chain convergence
diagnostics and reduces sensitivity to starting values. Pass
`n_chains > 1` to any model-fitting
function; chains are run in parallel via `parallel::mclapply` on Unix-like
systems.

```{r multi-chain}
set.seed(13)
fit_2chain <- ShrinkageTrees(
  y               = y_cont,
  X_train         = X_syn,
  outcome_type    = "continuous",
  prior_type      = "horseshoe",
  local_hp        = 0.1 / sqrt(5),
  global_hp       = 0.1 / sqrt(5),
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  n_chains        = 2,
  verbose         = FALSE
)

cat("n_chains stored  :", fit_2chain$mcmc$n_chains, "\n")
cat("Total sigma draws:", length(fit_2chain$sigma),
    " (2 chains x 50 draws)\n")
cat("Per-chain acceptance ratios:\n")
print(round(fit_2chain$chains$acceptance_ratios, 3))
```

The same interface works for causal models.

```{r multi-chain-causal}
set.seed(14)
fit_causal_2chain <- CausalShrinkageForest(
  y                         = y_causal,
  X_train_control           = X_syn,
  X_train_treat             = X_syn,
  treatment_indicator_train = W_syn,
  outcome_type              = "continuous",
  prior_type_control        = "horseshoe",
  prior_type_treat          = "horseshoe",
  local_hp_control          = lh,
  global_hp_control         = lh,
  local_hp_treat            = lh,
  global_hp_treat           = lh,
  number_of_trees_control   = 5,
  number_of_trees_treat     = 5,
  N_post                    = 50,
  N_burn                    = 25,
  n_chains                  = 2,
  verbose                   = FALSE
)

cat("Pooled sigma draws:", length(fit_causal_2chain$sigma), "\n")
cat("Per-chain acceptance ratios (control):\n")
print(round(fit_causal_2chain$chains$acceptance_ratios_control, 3))
```

With multiple chains the traceplot shows one line per chain, while the
overlaid density plot, which requires `n_chains > 1`, compares the
marginal posterior of $\sigma$ across chains.

```{r plot-trace-2chain, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_2chain, type = "trace")
```

```{r plot-density-2chain, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_2chain, type = "density")
```

## Convergence Diagnostics

MCMC methods require careful assessment of convergence before trusting
the posterior summaries. Here are practical guidelines for
**ShrinkageTrees** models.

### Sigma traceplot

The traceplot of the error standard deviation $\sigma$ (via
`plot(fit, type = "trace")`) is the primary diagnostic. A well-mixing
chain should show:

- **No trend**: the trace should fluctuate around a stable level after
  burn-in. A persistent upward or downward drift indicates the sampler
  has not yet converged — increase `N_burn`.
- **Good mixing**: the chain should move freely across its stationary
  range. If the trace is "sticky" (stays in the same region for long
  stretches), this indicates poor mixing.
- **Chain agreement** (when `n_chains > 1`): separate chains should
  overlap substantially. The density overlay
  (`plot(fit, type = "density")`) makes this easy to check visually.

### Acceptance ratio

The `summary()` output reports the average Metropolis–Hastings acceptance
ratio for the tree structure proposals (grow/prune moves). As a rough
guide:

- **0.15–0.50** is typical and healthy for tree-based MCMC.
- **Very low** (< 0.05) means most proposals are rejected. The sampler
  is barely exploring tree space. Consider increasing `N_burn` and
  `N_post`, or relaxing the tree structure prior (lower `power`, higher
  `base`).
- **Very high** (> 0.70) means almost all proposals are accepted, which
  typically indicates the trees are staying very small (e.g. stumps).
  This is less concerning but may limit the model's expressiveness.

### Formal diagnostics with coda

When the suggested package **coda** is installed, `summary()` automatically
reports **effective sample size** (ESS) and — for multi-chain fits — the
**Gelman–Rubin $\hat{R}$**.

```{r coda-summary, eval=requireNamespace("coda", quietly=TRUE)}
# summary() includes convergence diagnostics when coda is available
summary(fit_2chain)
```

For more detailed diagnostics, convert the fitted object to a
`coda::mcmc.list` with `as.mcmc.list()`:

```{r coda-diagnostics, eval=requireNamespace("coda", quietly=TRUE)}
library(coda)
mcmc_obj <- as.mcmc.list(fit_2chain)

# Gelman-Rubin R-hat (values near 1 indicate convergence)
coda::gelman.diag(mcmc_obj)

# Effective sample size
coda::effectiveSize(mcmc_obj)

# Geweke diagnostic (per chain)
coda::geweke.diag(mcmc_obj[[1]])
```

The returned `mcmc.list` object is compatible with all **coda** functions,
including `coda::autocorr.plot()`, `coda::gelman.plot()`,
`coda::heidel.diag()`, and `coda::raftery.diag()`.

### Recommended MCMC settings

The examples in this vignette use very small `N_post` and `N_burn` to
keep build time low. For a real analysis:

- **Minimum**: `N_post = 2000, N_burn = 2000`.
- **Recommended**: `N_post = 5000, N_burn = 5000`.
- **High-dimensional or survival**: `N_post = 5000, N_burn = 10000`
  (the AFT data augmentation step can slow mixing, so a longer burn-in
  helps).
- **Multiple chains**: `n_chains = 2` or `4` to verify convergence and
  produce pooled posterior samples.

## Case Study: TCGA PAAD (Full Analysis)

The full analysis of the `pdac` dataset replicates the case study from
Jacobs, van Wieringen & van der Pas (2025). Due to the high-dimensional
covariate space (~3,000 genes) and the large MCMC settings needed for
reliable inference, the code below is provided for reference but is not
evaluated during vignette building. Pre-computed results can be reproduced
by running the `pdac_analysis` demo: `demo("pdac_analysis", package = "ShrinkageTrees")`.

### Step 1: Propensity score estimation

```{r pdac-ps, eval=FALSE}
data("pdac")

time      <- pdac$time
status    <- pdac$status
treatment <- pdac$treatment
X         <- as.matrix(pdac[, !(names(pdac) %in% c("time","status","treatment"))])

set.seed(2025)
ps_fit <- HorseTrees(
  y            = treatment,
  X_train      = X,
  outcome_type = "binary",
  k            = 0.1,
  N_post       = 5000,
  N_burn       = 5000,
  verbose      = FALSE
)

propensity <- pnorm(ps_fit$train_predictions)
```

```{r pdac-ps-overlap, eval=FALSE}
# Overlap plot
p0 <- propensity[treatment == 0]
p1 <- propensity[treatment == 1]

hist(p0, breaks = 15, col = rgb(1, 0.5, 0, 0.5), xlim = range(propensity),
     xlab = "Propensity score", main = "Propensity score overlap")
hist(p1, breaks = 15, col = rgb(0, 0.5, 0, 0.5), add = TRUE)
legend("topright", legend = c("Control", "Treated"),
       fill = c(rgb(1,0.5,0,0.5), rgb(0,0.5,0,0.5)))
```

### Step 2: Causal survival forest

```{r pdac-causal, eval=FALSE}
# Augment control matrix with propensity scores (BCF-style)
X_control <- cbind(propensity, X)

# Log-transform and centre survival times
log_time <- log(time) - mean(log(time))

set.seed(2025)
fit_pdac <- CausalHorseForest(
  y                         = log_time,
  status                    = status,
  X_train_control           = X_control,
  X_train_treat             = X,
  treatment_indicator_train = treatment,
  outcome_type              = "right-censored",
  timescale                 = "log",
  number_of_trees           = 200,
  N_post                    = 5000,
  N_burn                    = 5000,
  store_posterior_sample    = TRUE,
  verbose                   = FALSE
)
```

### Step 3: ATE and CATE estimation

```{r pdac-ate, eval=FALSE}
# Print model summary
print(fit_pdac)
smry_pdac <- summary(fit_pdac)
print(smry_pdac)

# ATE
cat("ATE posterior mean:",
    round(smry_pdac$treatment_effect$ate, 3), "\n")
cat("95% CI: [",
    round(smry_pdac$treatment_effect$ate_lower, 3), ",",
    round(smry_pdac$treatment_effect$ate_upper, 3), "]\n")
```

### Step 4: Diagnostics

```{r pdac-diag, eval=FALSE}
# Sigma convergence
plot(fit_pdac, type = "trace")

# Posterior ATE density
plot(fit_pdac, type = "ate")

# CATE caterpillar plot
plot(fit_pdac, type = "cate")
```



## References

Chipman, H. A., George, E. I., & McCulloch, R. E. (2010). Bayesian Additive
Regression Trees. *Annals of Applied Statistics*, 4(1), 266–298.

Hahn, P. R., Murray, J. S., & Carvalho, C. M. (2020). Bayesian Regression
Tree Models for Causal Inference: Regularization, Confounding, and
Heterogeneous Treatment Effects. *Bayesian Analysis*, 15(3), 965–1056.

Jacobs, T., van Wieringen, W. N., & van der Pas, S. L. (2025). Horseshoe
Forests for High-Dimensional Causal Survival Analysis. *arXiv preprint*
arXiv:2507.22004.

Linero, A. R. (2018). Bayesian Regression Trees for High-Dimensional
Prediction and Variable Selection. *Journal of the American Statistical
Association*, 113(522), 626–636.

Sparapani, R., Spanbauer, C., & McCulloch, R. (2021). Nonparametric Machine
Learning and Efficient Computation with Bayesian Additive Regression Trees:
The BART R Package. *Journal of Statistical Software*, 97(1), 1–66.

