1. Simulate Raw Clinical Data

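We simulate a dataset with date-based event times, as one might receive from a clinical database. Patients start treatment and may experience heart failure, be cured, or die. Cured and Death are absorbing states. Heart Failure is a transient intermediate state, from which patients can still be cured or die. Follow-up ends at a random censoring time between 400 and 1095 days (about 3 years) after treatment start. Patients who have not reached an absorbing state by then are censored.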

# Simulation settings; the fixed seed makes every result in this vignette
# reproducible.
set.seed(2026)
n <- 400

# Baseline covariates
record_id <- seq_len(n)
gender <- sample(c("Male", "Female"), n, replace = TRUE)
trt <- sample(c("Drug A", "Drug B"), n, replace = TRUE)
weight <- round(rnorm(n, mean = 75, sd = 15), 1)

# Treatment start dates spread over 2 years
trt_date <- as.Date("2020-01-01") + sample(0:730, n, replace = TRUE)

# Outcome dates, filled in by the loop below; NA means the event did not
# happen (or was not observed) for that patient.
heart_failure_date <- rep(as.Date(NA), n)
cured_date <- rep(as.Date(NA), n)
death_date <- rep(as.Date(NA), n)
last_followup_date <- rep(as.Date(NA), n)

# Simulate trajectories one patient at a time. Drawing per patient fixes the
# order of random numbers; vectorizing the draws would change the simulated
# data (and therefore every downstream result shown below).
for (i in seq_len(n)) {
  # From Treatment state, competing latent Weibull times:
  #   -> Heart Failure (scale depends on weight, treatment)
  #   -> Cured (scale depends on treatment)
  #   -> Death (scale depends on weight)
  # Scalar condition, so use if/else rather than vectorized ifelse().
  trt_effect <- if (trt[i] == "Drug A") 0.7 else 1.0
  wt_effect <- exp((weight[i] - 75) / 50)

  t_hf <- rweibull(1, shape = 1.3, scale = 300 * trt_effect * wt_effect)
  t_cured <- rweibull(1, shape = 1.5, scale = 250 / trt_effect)
  t_death <- rweibull(1, shape = 1.0, scale = 800 * (1 / wt_effect))

  # The earliest latent time decides which transition actually happens
  first_times <- c(t_hf, t_cured, t_death)
  first_event <- which.min(first_times)
  first_time <- first_times[first_event]

  # Administrative censoring: uniform between 400 days and 1095 days
  # (3 years) after treatment start
  cens_time <- runif(1, 400, 1095)

  if (first_time > cens_time) {
    # Censored while still in the Treatment state
    last_followup_date[i] <- trt_date[i] + round(cens_time)
    next
  }

  if (first_event == 1) {
    # Heart Failure reached
    heart_failure_date[i] <- trt_date[i] + round(first_time)

    # From Heart Failure: -> Cured or -> Death. These times are measured
    # from HF onset, so they are added to first_time below.
    t_cured2 <- rweibull(1, shape = 1.4, scale = 200 / trt_effect)
    t_death2 <- rweibull(1, shape = 1.2, scale = 400 * (1 / wt_effect))

    second_times <- c(t_cured2, t_death2)
    second_event <- which.min(second_times)
    second_time <- first_time + second_times[second_event]

    # NOTE(review): rounding to whole days can, rarely, put the second event
    # on the same day as HF onset (a zero-length HeartFailure sojourn).
    if (second_time > cens_time) {
      last_followup_date[i] <- trt_date[i] + round(cens_time)
    } else if (second_event == 1) {
      cured_date[i] <- trt_date[i] + round(second_time)
    } else {
      death_date[i] <- trt_date[i] + round(second_time)
    }
  } else if (first_event == 2) {
    cured_date[i] <- trt_date[i] + round(first_time)
  } else {
    death_date[i] <- trt_date[i] + round(first_time)
  }
}

raw_data <- data.frame(
  record_id = record_id,
  gender = gender,
  trt = trt,
  weight = weight,
  trt_date = trt_date,
  cured_date = cured_date,
  heart_failure_date = heart_failure_date,
  death_date = death_date,
  last_followup_date = last_followup_date,
  stringsAsFactors = FALSE
)

# Sanity check: every patient ends with exactly one terminal record
# (cured, dead, or censored at last follow-up).
terminal_cols <- c("cured_date", "death_date", "last_followup_date")
stopifnot(rowSums(!is.na(raw_data[terminal_cols])) == 1)

head(raw_data, 10)
#>    record_id gender    trt weight   trt_date cured_date heart_failure_date
#> 1          1   Male Drug A   95.1 2021-04-26       <NA>               <NA>
#> 2          2   Male Drug A   83.3 2021-03-10       <NA>         2021-07-10
#> 3          3   Male Drug A   59.2 2021-12-07 2022-03-09               <NA>
#> 4          4 Female Drug B   72.1 2020-03-16 2020-05-07               <NA>
#> 5          5   Male Drug B   61.4 2021-03-02 2021-07-17         2021-04-11
#> 6          6   Male Drug B   45.0 2021-06-06       <NA>         2021-06-23
#> 7          7   Male Drug A  103.6 2020-04-04 2021-01-04               <NA>
#> 8          8 Female Drug A   98.5 2020-07-08 2021-08-09         2021-01-26
#> 9          9 Female Drug B   77.1 2021-03-08       <NA>               <NA>
#> 10        10 Female Drug B   53.7 2020-07-13 2020-10-11               <NA>
#>    death_date last_followup_date
#> 1  2021-07-05               <NA>
#> 2  2021-11-21               <NA>
#> 3        <NA>               <NA>
#> 4        <NA>               <NA>
#> 5        <NA>               <NA>
#> 6  2021-09-28               <NA>
#> 7        <NA>               <NA>
#> 8        <NA>               <NA>
#> 9  2021-05-17               <NA>
#> 10       <NA>               <NA>

2. Compute Time-to-Event from Treatment Date

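Convert each date column into the number of days since the patient's treatment start date. An event that did not occur has a missing date, and therefore a missing (NA) time.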

# Analysis table: one row per patient with 0/1-coded covariates and, for
# each event, the number of days from treatment start (NA if no date).
dat <- with(raw_data, data.frame(
  record_id = record_id,
  gender = as.integer(gender == "Male"),  # 1 = Male, 0 = Female
  trt = as.integer(trt == "Drug A"),      # 1 = Drug A, 0 = Drug B
  weight = weight,
  time_HeartFailure = as.numeric(
    difftime(heart_failure_date, trt_date, units = "days")
  ),
  time_Cured = as.numeric(difftime(cured_date, trt_date, units = "days")),
  time_Death = as.numeric(difftime(death_date, trt_date, units = "days")),
  time_censored = as.numeric(
    difftime(last_followup_date, trt_date, units = "days")
  ),
  stringsAsFactors = FALSE
))

head(dat, 10)
#>    record_id gender trt weight time_HeartFailure time_Cured time_Death
#> 1          1      1   1   95.1                NA         NA         70
#> 2          2      1   1   83.3               122         NA        256
#> 3          3      1   1   59.2                NA         92         NA
#> 4          4      0   0   72.1                NA         52         NA
#> 5          5      1   0   61.4                40        137         NA
#> 6          6      1   0   45.0                17         NA        114
#> 7          7      1   1  103.6                NA        275         NA
#> 8          8      0   1   98.5               202        397         NA
#> 9          9      0   0   77.1                NA         NA         70
#> 10        10      0   0   53.7                NA         90         NA
#>    time_censored
#> 1             NA
#> 2             NA
#> 3             NA
#> 4             NA
#> 5             NA
#> 6             NA
#> 7             NA
#> 8             NA
#> 9             NA
#> 10            NA

Quick summary of event counts:

# Each non-missing time is one observed event (or one censoring record).
cat("Total patients:", NROW(dat), "\n")
#> Total patients: 400
cat("Heart failure observed:",
    sum(!is.na(dat[["time_HeartFailure"]])), "\n")
#> Heart failure observed: 192
cat("Cured:", sum(!is.na(dat[["time_Cured"]])), "\n")
#> Cured: 259
cat("Death:", sum(!is.na(dat[["time_Death"]])), "\n")
#> Death: 130
cat("Censored (no absorbing state):",
    sum(!is.na(dat[["time_censored"]])), "\n")
#> Censored (no absorbing state): 11
#> Censored (no absorbing state): 11

3. Define Multistate Structure

library(RFmstate)

# Multistate structure matching the simulation: patients start in Treatment,
# HeartFailure is an intermediate state, and Cured / Death are absorbing
# (they have no entry in `transitions`). `transitions` lists, for each
# non-absorbing origin state, the states it can move to.
ms <- define_multistate(
  state_names = c("Treatment", "HeartFailure", "Cured", "Death"),
  absorbing = c("Cured", "Death"),
  transitions = list(
    Treatment = c("HeartFailure", "Cured", "Death"),
    HeartFailure = c("Cured", "Death")
  )
)
print(ms)
#> Multistate Structure
#>   States: Treatment -> HeartFailure -> Cured -> Death 
#>   Absorbing: Cured, Death 
#>   Transitions: 5 
#>     1: Treatment -> HeartFailure
#>     2: Treatment -> Cured
#>     3: Treatment -> Death
#>     4: HeartFailure -> Cured
#>     5: HeartFailure -> Death

4. Prepare Multistate Data

# Convert the one-row-per-patient table into multistate intervals.
# time_map: for each non-initial state, the column holding the day that
#   state was entered (NA in `dat` when it was not entered).
# censor_col: day of last follow-up for patients without an absorbing event.
msdata <- prepare_data(
  data = dat,
  id = "record_id",
  structure = ms,
  time_map = list(
    HeartFailure = "time_HeartFailure",
    Cured = "time_Cured",
    Death = "time_Death"
  ),
  censor_col = "time_censored",
  covariates = c("gender", "trt", "weight")
)
print(msdata)
#> Multistate Data (msdata)
#>   Patients: 400 
#>   Intervals: 592 
#>   Transitions observed: 581 
#>   Censored intervals: 11 
#>   States: Treatment, HeartFailure, Cured, Death 
#> 
#> Transition counts:
#>               to
#> from           Cured Death HeartFailure
#>   HeartFailure   120    61            0
#>   Treatment      139    69          192

5. Transition Diagram

# Diagram of the multistate structure annotated with observed event counts.
plot_transition_diagram(ms, msdata)
State transition diagram with event counts.

State transition diagram with event counts.

6. Aalen-Johansen Nonparametric Estimates

# Nonparametric Aalen-Johansen estimate computed from msdata alone
# (no covariates are passed).
aj <- aalen_johansen(msdata)
print(aj)
#> Aalen-Johansen Estimate
#>   Time range: [1, 644]
#>   Event times: 312 
#>   States: Treatment, HeartFailure, Cured, Death 
#> 
#> Event counts per transition:
#>          from           to n_events
#>     Treatment HeartFailure      192
#>     Treatment        Cured      139
#>     Treatment        Death       69
#>  HeartFailure        Cured      120
#>  HeartFailure        Death       61
#> 
#> Final state occupation probabilities:
#>   Treatment: 0
#>   HeartFailure: 0.011
#>   Cured: 0.6624
#>   Death: 0.3267
# Probability of being in each state over time.
plot(aj, type = "state_occupation")
State occupation probabilities (Aalen-Johansen).

State occupation probabilities (Aalen-Johansen).

# Nelson-Aalen cumulative hazard for each transition.
plot(aj, type = "cumulative_hazard")
Nelson-Aalen cumulative hazards by transition.

Nelson-Aalen cumulative hazards by transition.

# Stacked transition probabilities starting from the Treatment state.
plot(aj, type = "stacked_transition_prob")
Transition probabilities from Treatment state (AJ).

Transition probabilities from Treatment state (AJ).

# Estimated transition intensities over time.
plot(aj, type = "transition_intensity")
Transition intensities over time (AJ).

Transition intensities over time (AJ).

7. Fit Random Forest Model

# Random forest multistate model; as the printed output shows, models are
# fitted per origin state. `seed` is passed for reproducibility. The call is
# echoed by print(), so keep it readable.
fit <- rfmstate(
  msdata,
  covariates = c("gender", "trt", "weight"),
  num.trees = 500,
  min.node.size = 15,
  seed = 2026
)
print(fit)
#> Random Forest Multistate Model
#> Call: rfmstate(msdata = msdata, covariates = c("gender", "trt", "weight"), 
#>     num.trees = 500, min.node.size = 15, seed = 2026)
#> 
#> Covariates: gender, trt, weight 
#> Parameters:
#>   num.trees: 500 
#>   mtry: 1 
#>   min.node.size: 15 
#> 
#> Models fitted per origin state:
#>   Treatment (n=400): -> HeartFailure, Cured, Death
#>   HeartFailure (n=192): -> Cured, Death

8. Model Summary

# Assigning the summary suppresses R's auto-printing, so print it explicitly;
# otherwise this section displays nothing.
s <- summary(fit)
print(s)

9. Feature Importance

# Importance of each covariate, reported separately for every transition.
imp <- importance(fit)
print(imp)
#> Feature Importance per Transition
#> ============================================================ 
#> 
#>        Treatment -> HeartFailure Treatment -> Cured Treatment -> Death
#> gender                   -0.0038            -0.0041             0.0133
#> trt                       0.0304             0.0614             0.0165
#> weight                    0.0292            -0.0036             0.0082
#>        HeartFailure -> Cured HeartFailure -> Death
#> gender               -0.0027                0.0022
#> trt                   0.0496                0.0180
#> weight               -0.0132                0.0206
#> 
#> Top variables per transition:
#>   Treatment -> HeartFailure: trt (0.0304)
#>   Treatment -> Cured: trt (0.0614)
#>   Treatment -> Death: trt (0.0165)
#>   HeartFailure -> Cured: trt (0.0496)
#>   HeartFailure -> Death: weight (0.0206)
# Bar chart of covariate importance per transition.
plot(imp, type = "barplot")
Feature importance per transition.

Feature importance per transition.

# The same importance values as a covariate-by-transition heatmap.
plot(imp, type = "heatmap")
Feature importance heatmap.

Feature importance heatmap.

10. Predict for New Patients

# Three illustrative profiles, coded like `dat`
# (gender: 1 = Male, 0 = Female; trt: 1 = Drug A, 0 = Drug B).
# Row names label each profile in the printed table.
new_patients <- data.frame(
  gender = c(1, 0, 1),
  trt = c(1, 0, 1),
  weight = c(65, 90, 75),
  row.names = c("Light male, Drug A",
                "Heavy female, Drug B",
                "Average male, Drug A")
)
print(new_patients)
#>                      gender trt weight
#> Light male, Drug A        1   1     65
#> Heavy female, Drug B      0   0     90
#> Average male, Drug A      1   1     75

# Predict on a 30-day grid. seq(30, 1000, by = 30) stops at 990. The
# Aalen-Johansen time range printed above ends at 644 days, so later grid
# points may be extrapolations beyond the observed events.
pred <- predict(fit, newdata = new_patients, times = seq(30, 1000, by = 30))
print(pred)
#> RF Multistate Predictions
#>   Subjects: 3 
#>   Time points: 33 
#>   Time range: [30, 990]
#>   States: Treatment, HeartFailure, Cured, Death
# Subject 1 = first row of new_patients ("Light male, Drug A").
plot(pred, type = "state_occupation", subject = 1)
Predicted state occupation: light male on Drug A.

Predicted state occupation: light male on Drug A.

# Subject 2 = second row of new_patients ("Heavy female, Drug B").
plot(pred, type = "state_occupation", subject = 2)
Predicted state occupation: heavy female on Drug B.

Predicted state occupation: heavy female on Drug B.

# Subject 3 = third row of new_patients ("Average male, Drug A").
plot(pred, type = "state_occupation", subject = 3)
Predicted state occupation: average male on Drug A.

Predicted state occupation: average male on Drug A.

# Predicted transition probabilities for subject 1 ("Light male, Drug A").
plot(pred, type = "transition_prob", subject = 1)
Predicted transition probabilities: light male on Drug A.

Predicted transition probabilities: light male on Drug A.

11. Diagnostics

# Per-transition diagnostics: out-of-bag prediction error, concordance
# index, and a bias-variance decomposition (see printed output).
diag <- diagnose(fit)
print(diag)
#> RF Multistate Model Diagnostics
#> ============================================================ 
#> 
#> OOB Prediction Error:
#> ---------------------------------------- 
#>   Treatment -> HeartFailure 0.4178
#>   Treatment -> Cured        0.4319
#>   Treatment -> Death        0.4828
#>   HeartFailure -> Cured     0.4367
#>   HeartFailure -> Death     0.4626
#> 
#> Concordance Index (C-index):
#> ---------------------------------------- 
#>   Treatment -> HeartFailure 0.6804
#>   Treatment -> Cured        0.7098
#>   Treatment -> Death        0.7603
#>   HeartFailure -> Cured     0.6796
#>   HeartFailure -> Death     0.7464
#> 
#> Bias-Variance Decomposition:
#> ------------------------------------------------------------ 
#>   Transition                    Bias      Var      MSE
#> ------------------------------------------------------------ 
#>   Treatment -> HeartFailure   0.0391   0.0057   0.1982
#>   Treatment -> Cured          0.0634   0.0063   0.1868
#>   Treatment -> Death          0.0250   0.0022   0.1112
#>   HeartFailure -> Cured       0.0384   0.0090   0.2103
#>   HeartFailure -> Death       0.0220   0.0037   0.1482
# Time-dependent Brier score for each transition.
plot(diag, type = "brier")
Time-dependent Brier score by transition.

Time-dependent Brier score by transition.

# Concordance index for each transition.
plot(diag, type = "concordance")
Concordance index by transition.

Concordance index by transition.

# Bias-variance decomposition for each transition.
plot(diag, type = "bias_variance")
Bias-variance decomposition by transition.

Bias-variance decomposition by transition.

12. Comparing Treatment Arms

We can compare predicted outcomes between Drug A and Drug B for a male patient of average weight (75), keeping all other covariates fixed.

# Two profiles that differ only in treatment: male (gender = 1), weight 75.
drug_a <- data.frame(gender = 1, trt = 1, weight = 75)
drug_b <- data.frame(gender = 1, trt = 0, weight = 75)

# Shared prediction grid for both arms (30-day steps, last point at 990).
pred_times <- seq(30, 1000, by = 30)
pred_a <- predict(fit, newdata = drug_a, times = pred_times)
pred_b <- predict(fit, newdata = drug_b, times = pred_times)

times <- pred_a$time
states <- ms$state_names

# One panel per state. par() returns the previous settings, which are
# restored after plotting so later plots are not left in a 2 x 2 layout.
old_par <- par(mfrow = c(2, 2), mar = c(4, 4, 3, 1))
cols <- c("#1b9e77", "#d95f02")
for (j in seq_along(states)) {
  # NOTE(review): assumes state_occ is indexed [subject, state, time] with
  # states in the order of ms$state_names -- confirm against RFmstate docs.
  occ_a <- pred_a$state_occ[1, j, ]
  occ_b <- pred_b$state_occ[1, j, ]

  # Guard the y-axis limit: NA values would make it non-finite and an
  # all-zero curve would make it degenerate.
  y_max <- suppressWarnings(max(c(occ_a, occ_b), na.rm = TRUE))
  if (!is.finite(y_max) || y_max <= 0) {
    y_max <- 1
  }

  plot(times, occ_a, type = "l", col = cols[1], lwd = 2,
       ylim = c(0, y_max * 1.1),
       xlab = "Days", ylab = "Probability",
       main = states[j])
  lines(times, occ_b, col = cols[2], lwd = 2)
  legend("topright", legend = c("Drug A", "Drug B"),
         col = cols, lwd = 2, bty = "n", cex = 0.8)
}
par(old_par)