---
title: "Introduction to MLCausal"
author: "MLCausal Team"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Introduction to MLCausal}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
knitr::opts_chunk$set(collapse = TRUE, comment = "#>",
                      fig.width = 7, fig.height = 4)
```

## Overview

**MLCausal** provides a tidy, end-to-end pipeline for causal inference in
clustered (multilevel) observational data — students in schools, patients in
hospitals, employees in firms.

Every function that needs one of the following arguments uses the same name for it:

| Argument | Meaning |
|---|---|
| `treatment` | Name of the 0/1 treatment variable |
| `outcome` | Name of the outcome variable |
| `covariates` | Character vector of covariate names |
| `cluster` | Name of the cluster identifier |
| `weights` | Name of the weight variable |

The main workflow:

```
simulate_ml_data()  →  ml_ps()  →  ml_weight() or ml_match()
  →  balance_ml()  →  estimate_att_ml()  →  sens_ml()
```

---

## 1. Simulate Clustered Data

```{r simulate}
library(MLCausal)

dat <- simulate_ml_data(n_clusters = 20, cluster_size = 25,
                        n_min = 10, seed = 42)
head(dat)
table(dat$z)
```

In this simulation the true ATT is approximately 0.5, or slightly above it,
because treated units are over-represented in clusters with larger treatment
effects.
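
The reason the ATT exceeds 0.5 can be reproduced with a toy data-generating
process. The chunk below is a hypothetical base R simulation, not the code
inside `simulate_ml_data()`. In it, a cluster effect `u` raises both the
probability of treatment and the size of the treatment effect. Averaging the
effect over treated units therefore gives a larger value than averaging it
over all units, which is close to 0.5.

```{r att-dgp-sketch}
# Hypothetical data-generating process (not simulate_ml_data()'s actual code):
# a cluster effect u raises both the chance of treatment and the treatment effect.
set.seed(42)
n_cl <- 200
size <- 25
u    <- rnorm(n_cl)                        # cluster-level random effect
sim  <- data.frame(cl = rep(seq_len(n_cl), each = size))
sim$x1  <- rnorm(nrow(sim))
sim$tau <- 0.5 + 0.2 * u[sim$cl]           # each unit inherits its cluster's effect
sim$z   <- rbinom(nrow(sim), 1, plogis(-0.5 + 0.5 * sim$x1 + 0.8 * u[sim$cl]))

ate_true <- mean(sim$tau)                  # average effect over all units
att_true <- mean(sim$tau[sim$z == 1])      # average effect over treated units only
c(ATE = ate_true, ATT = att_true)
```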

---

## 2. Estimate Propensity Scores

```{r ps}
ps <- ml_ps(
  data       = dat,
  treatment  = "z",
  covariates = c("x1", "x2", "x3"),
  cluster    = "school_id",
  method     = "mundlak",
  estimand   = "ATT"
)
print(ps)
```
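
`method = "mundlak"` presumably refers to the Mundlak device. This device
augments a single-level propensity model with the cluster means of the
covariates so that cluster-level confounding enters the model. The chunk
below is a minimal base R sketch of that idea, assuming a logistic link and a
single covariate. `ml_ps()` may differ in its details.

```{r mundlak-sketch}
# Hypothetical Mundlak-style propensity model in base R (not ml_ps() itself)
set.seed(1)
n_cl <- 30
size <- 20
u_cl <- rnorm(n_cl)                                 # unobserved cluster effect
mk   <- data.frame(cl = rep(seq_len(n_cl), each = size))
mk$x1 <- rnorm(nrow(mk)) + u_cl[mk$cl]              # covariate shares the cluster effect
mk$z  <- rbinom(nrow(mk), 1, plogis(-0.3 + 0.6 * mk$x1 + 0.5 * u_cl[mk$cl]))

# Mundlak device: add each covariate's cluster mean as an extra predictor
mk$x1_cm  <- ave(mk$x1, mk$cl)                      # ave() defaults to the group mean
fit_mk    <- glm(z ~ x1 + x1_cm, family = binomial(), data = mk)
mk$ps_hat <- fitted(fit_mk)
summary(mk$ps_hat)
```

With several covariates, each one gets its own cluster-mean column.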

---

## 3. Check Overlap

```{r overlap, fig.alt="Propensity score overlap plot"}
plot_overlap_ml(ps)
```
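
The plot can be complemented by a numeric check. The hypothetical helper
`outside_support()`, which is not part of MLCausal, finds the interval where
the treated and control propensity score ranges overlap and returns the share
of units outside it. For the ATT, the main concern is treated units whose
scores exceed every control score, because they have no comparable controls.

```{r overlap-sketch}
# Hypothetical helper: share of units outside the common-support region
outside_support <- function(ps, z) {
  lower <- max(min(ps[z == 1]), min(ps[z == 0]))
  upper <- min(max(ps[z == 1]), max(ps[z == 0]))
  mean(ps < lower | ps > upper)
}

ps_toy <- c(0.05, 0.20, 0.40, 0.60, 0.30, 0.55, 0.70, 0.95)
z_toy  <- c(0, 0, 0, 0, 1, 1, 1, 1)
outside_support(ps_toy, z_toy)   # half of the toy units fall outside [0.30, 0.60]
```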

---

## 4. Build Inverse Probability Weights

```{r weights}
dat_w <- ml_weight(ps, estimand = "ATT", stabilize = TRUE, trim = 10)
summary(dat_w$weights)
```
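
For the ATT, treated units conventionally receive weight 1 and each control
receives the odds of its propensity score, `ps / (1 - ps)`. This reweights the
controls to resemble the treated group. The sketch below assumes two rules
that `ml_weight()` may define differently. First, `stabilize = TRUE` rescales
the control weights so that they sum to the number of controls. Second,
`trim` caps every weight at the given value.

```{r att-weights-sketch}
# Hypothetical ATT weights; the stabilize and trim rules here are assumptions
att_weights <- function(ps, z, stabilize = TRUE, trim = NULL) {
  w <- ifelse(z == 1, 1, ps / (1 - ps))      # treated: 1; controls: odds of treatment
  if (stabilize) {
    ctrl <- z == 0
    w[ctrl] <- w[ctrl] * sum(ctrl) / sum(w[ctrl])   # control weights sum to their count
  }
  if (!is.null(trim)) w <- pmin(w, trim)     # cap every weight at `trim`
  w
}

att_weights(ps = c(0.5, 0.8, 0.2, 0.5), z = c(1, 1, 0, 0))
```

Trimming is applied after stabilization in this sketch, so trimmed control
weights no longer sum exactly to the number of controls.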

---

## 5. Check Balance

```{r balance-weight}
bal <- balance_ml(
  data       = dat_w,
  treatment  = "z",
  covariates = c("x1", "x2", "x3"),
  cluster    = "school_id",
  weights    = "weights"
)
print(bal)
```

In the printed output, individual-level SMDs are numeric, whereas cluster-mean
SMDs are character strings. A cluster-mean SMD is shown as a formatted number
when it can be estimated and as a descriptive message when it cannot.
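
To make the two kinds of SMD concrete, the chunk below computes a weighted
standardized mean difference for a covariate and for its cluster mean. The SMD
uses weighted group means divided by the unweighted pooled SD. The helper
`wsmd()` is hypothetical, and `balance_ml()` may use a different formula. In
the toy data each cluster contains only one treatment group, so the cluster
means have zero within-group variance. As a result, the cluster-mean SMD is
not finite, which is one way such an SMD can fail to be estimable.

```{r smd-sketch}
# Hypothetical weighted SMD: weighted group means over the unweighted pooled SD
wsmd <- function(x, z, w = rep(1, length(x))) {
  m1 <- weighted.mean(x[z == 1], w[z == 1])
  m0 <- weighted.mean(x[z == 0], w[z == 0])
  s  <- sqrt((var(x[z == 1]) + var(x[z == 0])) / 2)
  (m1 - m0) / s
}

x_toy  <- c(1, 2, 3, 4)
z_toy  <- c(1, 1, 0, 0)
cl_toy <- c("a", "a", "b", "b")

wsmd(x_toy, z_toy)                  # individual-level SMD
wsmd(ave(x_toy, cl_toy), z_toy)     # SMD of the cluster means (zero pooled SD here)
```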

---

## 6. Estimate the ATT

```{r estimate}
est <- estimate_att_ml(
  data       = dat_w,
  outcome    = "y",
  treatment  = "z",
  cluster    = "school_id",
  covariates = c("x1", "x2", "x3"),
  weights    = "weights"
)
print(est)
```
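
A common way to obtain a weighted ATT with cluster-robust uncertainty is
weighted least squares of the outcome on treatment. The variance is then
estimated with a sandwich formula that sums score contributions within each
cluster. The base R sketch below uses the basic CR0 form and omits both
covariate adjustment and small-sample corrections. It is not necessarily what
`estimate_att_ml()` does. Without covariates, the coefficient on `z` equals
the weighted difference in group means.

```{r att-se-sketch}
# Hypothetical helper: WLS of y on z with a CR0 cluster-robust standard error
att_cluster_robust <- function(y, z, w, cl) {
  X     <- cbind(intercept = 1, z = z)
  bread <- solve(crossprod(X, w * X))          # (X'WX)^{-1}
  b     <- bread %*% crossprod(X, w * y)       # WLS coefficients
  e     <- drop(y - X %*% b)                   # raw residuals
  U     <- rowsum(X * (w * e), cl)             # score sums, one row per cluster
  V     <- bread %*% crossprod(U) %*% bread    # sandwich variance
  c(att = b[2, 1], se = sqrt(V[2, 2]))
}

att_cluster_robust(y  = c(2, 4, 1, 3, 5, 2), z = c(1, 1, 0, 0, 1, 0),
                   w  = c(1, 3, 2, 2, 1, 1), cl = c(1, 2, 1, 2, 3, 3))
```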

---

## 7. Sensitivity Analysis

```{r sensitivity}
sens <- sens_ml(estimate = est$estimate, se = est$se)
# First row of the sensitivity table flagged as crossing the null
sens[sens$crosses_null, ][1, ]
```
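
The chunk below illustrates the general logic of an additive-bias sensitivity
analysis. It subtracts a hypothetical bias from the estimate, recomputes the
95% interval, and flags the rows where the interval includes zero. The
function name `bias_grid()` and its additive `bias` grid are assumptions for
illustration. `sens_ml()` and its `q` parameter may be defined differently.

```{r sens-sketch}
# Hypothetical additive-bias sensitivity grid (not sens_ml() itself)
bias_grid <- function(estimate, se, bias = seq(0, 1, by = 0.1), level = 0.95) {
  crit <- qnorm(1 - (1 - level) / 2)
  adj  <- estimate - bias                     # estimate after removing the bias
  data.frame(bias = bias, estimate = adj,
             lower = adj - crit * se, upper = adj + crit * se,
             crosses_null = adj - crit * se <= 0 & adj + crit * se >= 0)
}

grid_toy <- bias_grid(estimate = 0.5, se = 0.1)
grid_toy[grid_toy$crosses_null, ][1, ]       # smallest grid bias whose CI includes zero
```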

---

## Alternative: Dual-Balance Matching

The `lambda` parameter is what distinguishes `ml_match()` from standard
propensity score matching. It adds a cluster-mean balance penalty to the
matching distance, so that matches are chosen to improve balance at *both* the
individual and the cluster-mean level simultaneously.
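
One simplified way to picture a composite distance starts with the propensity
score gap between a treated and a control unit. To this gap, add `lambda`
times the gap between the two units' cluster means of a covariate, then match
greedily. The `composite_match()` function below is a hypothetical base R
proxy. The actual penalty, caliper scale, and matching algorithm in
`ml_match()` may differ. In the toy data, `lambda = 0` picks the control with
the closest score (unit 2). With `lambda = 1`, the sketch instead prefers unit
3, whose score is slightly further away but whose cluster mean is much closer.

```{r composite-match-sketch}
# Hypothetical greedy 1:1 matching on a composite distance (not ml_match() itself)
composite_match <- function(ps, xbar, z, lambda = 1, caliper = Inf) {
  controls <- which(z == 0)
  pairs <- data.frame(treated = integer(0), control = integer(0))
  for (i in which(z == 1)) {
    if (length(controls) == 0) break
    d_ps <- abs(ps[i] - ps[controls])
    # composite distance: PS gap plus lambda times the cluster-mean covariate gap
    d <- d_ps + lambda * abs(xbar[i] - xbar[controls])
    d[d_ps > caliper] <- Inf                # caliper on the PS gap only
    if (all(is.infinite(d))) next           # no eligible control for unit i
    j <- controls[which.min(d)]
    pairs <- rbind(pairs, data.frame(treated = i, control = j))
    controls <- setdiff(controls, j)        # match without replacement
  }
  pairs
}

ps_m   <- c(0.50, 0.52, 0.60, 0.30)   # unit 1 is treated; units 2 to 4 are controls
xbar_m <- c(1.00, 3.00, 1.05, 1.00)   # cluster mean of a covariate for each unit
z_m    <- c(1, 0, 0, 0)

composite_match(ps_m, xbar_m, z_m, lambda = 0)
composite_match(ps_m, xbar_m, z_m, lambda = 1)
```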

```{r matching}
# lambda = 0  → standard PS matching
# lambda = 1  → equal weight on PS distance and cluster-mean balance (default)
# lambda > 1  → prioritise cluster-mean balance over PS proximity

matched <- ml_match(ps, ratio = 1, caliper = 0.5, lambda = 1)
print(matched)

bal_m <- balance_ml(
  data       = matched$data_matched,
  treatment  = "z",
  covariates = c("x1", "x2", "x3"),
  cluster    = "school_id",
  weights    = "match_weight"
)
print(bal_m)

est_m <- estimate_att_ml(
  data       = matched$data_matched,
  outcome    = "y",
  treatment  = "z",
  cluster    = "school_id",
  covariates = c("x1", "x2", "x3"),
  weights    = "match_weight"
)
print(est_m)
```

Compare `bal` (weighting) with `bal_m` (dual-balance matching). If the
composite distance works as intended, the matched sample should show smaller
cluster-mean SMDs than the weighted sample.

---

## Summary

| Step | Function | Key argument change |
|---|---|---|
| Simulate | `simulate_ml_data()` | `n_min` prevents tiny clusters |
| PS model | `ml_ps()` | `treatment =` (not `treat`) |
| Weight | `ml_weight()` | output column is `weights` |
| Match | `ml_match()` | `lambda =` for dual-balance |
| Balance | `balance_ml()` | `treatment =`; cluster-mean SMD is a string, not `NA` |
| Estimate | `estimate_att_ml()` | `treatment =`, `weights =` |
| Sensitivity | `sens_ml()` | default `q` extended to 5 |
