## ----setup, include = FALSE---------------------------------------------------
# Vignette-wide knitr options: show code and output together, prefix output
# lines with "#>", and use a 7 x 4 default figure size.
knitr::opts_chunk$set(
  collapse   = TRUE,
  comment    = "#>",
  fig.width  = 7,
  fig.height = 4
)

## ----simulate-----------------------------------------------------------------
library(MLCausal)

# Simulated example data (fixed seed); the columns used below include the
# treatment `z`, covariates x1-x3, outcome `y` and cluster id `school_id`.
dat <- simulate_ml_data(
  n_clusters   = 20,
  cluster_size = 25,
  n_min        = 10,
  seed         = 42
)
head(dat, n = 6L)

# Number of units in each treatment group.
table(dat[["z"]])

## ----ps-----------------------------------------------------------------------
# Propensity-score model for treatment `z` given x1-x3, with `school_id`
# identifying clusters; uses the "mundlak" method and targets the ATT.
ps <- ml_ps(
  data       = dat,
  treatment  = "z",
  covariates = paste0("x", 1:3),
  cluster    = "school_id",
  method     = "mundlak",
  estimand   = "ATT"
)
print(ps)

## ----overlap, fig.alt="Propensity score overlap plot"-------------------------
# Visual check of propensity-score overlap between treatment groups.
MLCausal::plot_overlap_ml(ps)

## ----weights------------------------------------------------------------------
# Turn the fitted propensity scores into ATT weights (stabilized; see
# ?ml_weight for the exact meaning of `trim = 10`), then summarise them.
dat_w <- ml_weight(
  ps,
  estimand  = "ATT",
  stabilize = TRUE,
  trim      = 10
)
summary(dat_w[["weights"]])

## ----balance-weight-----------------------------------------------------------
# Covariate balance on x1-x3 after weighting, using the `weights` column
# created by ml_weight().
bal <- balance_ml(
  data       = dat_w,
  treatment  = "z",
  covariates = paste0("x", 1:3),
  cluster    = "school_id",
  weights    = "weights"
)
print(bal)

## ----estimate-----------------------------------------------------------------
# Weighted ATT estimate for outcome `y`, adjusting for x1-x3 and accounting
# for clustering by `school_id`.
est <- estimate_att_ml(
  data       = dat_w,
  outcome    = "y",
  treatment  = "z",
  cluster    = "school_id",
  covariates = paste0("x", 1:3),
  weights    = "weights"
)
print(est)

## ----sensitivity--------------------------------------------------------------
# Sensitivity analysis built from the ATT point estimate and its standard error.
sens <- sens_ml(estimate = est$estimate, se = est$se)

# First row flagged by `crosses_null`. `which()` skips NA flags, and
# `head(, 1)` yields zero rows (instead of a fabricated all-NA row) when no
# row is flagged.
head(sens[which(sens$crosses_null), , drop = FALSE], 1)

## ----matching-----------------------------------------------------------------
# How `lambda` trades off the two matching criteria:
#   lambda = 0 -> standard PS matching
#   lambda = 1 -> PS distance and cluster-mean balance weighted equally
#                 (the default)
#   lambda > 1 -> cluster-mean balance prioritised over PS proximity

matched <- ml_match(ps, ratio = 1, caliper = 0.5, lambda = 1)
print(matched)

# Balance on x1-x3 in the matched sample, weighted by `match_weight`.
bal_m <- balance_ml(
  data       = matched$data_matched,
  treatment  = "z",
  covariates = paste0("x", 1:3),
  cluster    = "school_id",
  weights    = "match_weight"
)
print(bal_m)

# ATT for `y` re-estimated on the matched sample.
est_m <- estimate_att_ml(
  data       = matched$data_matched,
  outcome    = "y",
  treatment  = "z",
  cluster    = "school_id",
  covariates = paste0("x", 1:3),
  weights    = "match_weight"
)
print(est_m)

