We simulate a dataset with date-based event times, as one might receive from a clinical database. Patients start treatment and may experience heart failure, be cured, or die. Cured and Death are absorbing states; Heart Failure is a transient intermediate state.
# Simulate a clinical-style dataset: one row per patient with baseline
# covariates, a treatment start date, and the calendar date of each event.
# The sequence of random draws below determines the simulated values, so the
# order of the sample()/rnorm()/rweibull()/runif() calls must not be changed.
set.seed(2026)
n <- 400
record_id <- seq_len(n)
gender <- sample(c("Male", "Female"), n, replace = TRUE)
trt <- sample(c("Drug A", "Drug B"), n, replace = TRUE)
weight <- round(rnorm(n, mean = 75, sd = 15), 1)
# Treatment start dates spread over 2 years
trt_date <- as.Date("2020-01-01") + sample(0:730, n, replace = TRUE)
# Event dates start out missing and are filled in only when the event is
# observed for that patient.
heart_failure_date <- rep(as.Date(NA), n)
cured_date <- rep(as.Date(NA), n)
death_date <- rep(as.Date(NA), n)
last_followup_date <- rep(as.Date(NA), n)
for (i in seq_len(n)) {
  # Scalar condition, so plain if/else rather than the vectorised ifelse().
  # A value of 0.7 shrinks the heart-failure scale and inflates the cure
  # scales for Drug A relative to Drug B.
  trt_effect <- if (trt[i] == "Drug A") 0.7 else 1.0
  # Greater than 1 above a weight of 75: multiplies the heart-failure scale
  # and divides the death scales.
  wt_effect <- exp((weight[i] - 75) / 50)
  # From Treatment state, competing events:
  #   -> Heart Failure (scale depends on weight, treatment)
  #   -> Cured (scale depends on treatment)
  #   -> Death (scale depends on weight)
  t_hf <- rweibull(1, shape = 1.3, scale = 300 * trt_effect * wt_effect)
  t_cured <- rweibull(1, shape = 1.5, scale = 250 / trt_effect)
  t_death <- rweibull(1, shape = 1.0, scale = 800 * (1 / wt_effect))
  # The earliest of the three latent times decides the first transition.
  first_times <- c(t_hf, t_cured, t_death)
  first_event <- which.min(first_times)
  first_time <- first_times[first_event]
  # Censoring time drawn uniformly between 400 and 1095 days (3 years)
  cens_time <- runif(1, 400, 1095)
  if (first_time > cens_time) {
    # Censored from Treatment state: no event date is recorded
    last_followup_date[i] <- trt_date[i] + round(cens_time)
    next
  }
  if (first_event == 1) {
    # Heart Failure reached
    heart_failure_date[i] <- trt_date[i] + round(first_time)
    # From Heart Failure: -> Cured or -> Death
    t_cured2 <- rweibull(1, shape = 1.4, scale = 200 / trt_effect)
    t_death2 <- rweibull(1, shape = 1.2, scale = 400 * (1 / wt_effect))
    second_times <- c(t_cured2, t_death2)
    second_event <- which.min(second_times)
    # Second-stage times are measured from heart failure, so shift them back
    # to the treatment-start clock before comparing with cens_time.
    second_time <- first_time + second_times[second_event]
    if (second_time > cens_time) {
      # Censored while in Heart Failure; the heart failure date stays recorded
      last_followup_date[i] <- trt_date[i] + round(cens_time)
    } else if (second_event == 1) {
      cured_date[i] <- trt_date[i] + round(second_time)
    } else {
      death_date[i] <- trt_date[i] + round(second_time)
    }
  } else if (first_event == 2) {
    cured_date[i] <- trt_date[i] + round(first_time)
  } else {
    death_date[i] <- trt_date[i] + round(first_time)
  }
}
# Assemble the per-patient vectors into one data frame, one row per patient.
raw_data <- data.frame(
  record_id = record_id,
  gender = gender,
  trt = trt,
  weight = weight,
  trt_date = trt_date,
  cured_date = cured_date,
  heart_failure_date = heart_failure_date,
  death_date = death_date,
  last_followup_date = last_followup_date,
  stringsAsFactors = FALSE
)
head(raw_data, 10)
#> record_id gender trt weight trt_date cured_date heart_failure_date
#> 1 1 Male Drug A 95.1 2021-04-26 <NA> <NA>
#> 2 2 Male Drug A 83.3 2021-03-10 <NA> 2021-07-10
#> 3 3 Male Drug A 59.2 2021-12-07 2022-03-09 <NA>
#> 4 4 Female Drug B 72.1 2020-03-16 2020-05-07 <NA>
#> 5 5 Male Drug B 61.4 2021-03-02 2021-07-17 2021-04-11
#> 6 6 Male Drug B 45.0 2021-06-06 <NA> 2021-06-23
#> 7 7 Male Drug A 103.6 2020-04-04 2021-01-04 <NA>
#> 8 8 Female Drug A 98.5 2020-07-08 2021-08-09 2021-01-26
#> 9 9 Female Drug B 77.1 2021-03-08 <NA> <NA>
#> 10 10 Female Drug B 53.7 2020-07-13 2020-10-11 <NA>
#> death_date last_followup_date
#> 1 2021-07-05 <NA>
#> 2 2021-11-21 <NA>
#> 3 <NA> <NA>
#> 4 <NA> <NA>
#> 5 <NA> <NA>
#> 6 2021-09-28 <NA>
#> 7 <NA> <NA>
#> 8 <NA> <NA>
#> 9 2021-05-17 <NA>
#> 10 <NA> <NA>
Convert the date columns to days since treatment start, and recode gender and treatment as 0/1 indicators (1 = Male, 1 = Drug A).
# Number of days from each patient's treatment start to `event_date`;
# stays NA wherever the event date is missing.
days_from_start <- function(event_date) {
  elapsed <- difftime(event_date, raw_data$trt_date, units = "days")
  as.numeric(elapsed)
}

# Analysis dataset: covariates as numbers (gender: 1 = "Male", trt: 1 =
# "Drug A") and one elapsed-days column per event type plus censoring.
dat <- data.frame(
  record_id = raw_data$record_id,
  gender = as.integer(raw_data$gender == "Male"),
  trt = as.integer(raw_data$trt == "Drug A"),
  weight = raw_data$weight,
  time_HeartFailure = days_from_start(raw_data$heart_failure_date),
  time_Cured = days_from_start(raw_data$cured_date),
  time_Death = days_from_start(raw_data$death_date),
  time_censored = days_from_start(raw_data$last_followup_date),
  stringsAsFactors = FALSE
)
head(dat, 10)
#> record_id gender trt weight time_HeartFailure time_Cured time_Death
#> 1 1 1 1 95.1 NA NA 70
#> 2 2 1 1 83.3 122 NA 256
#> 3 3 1 1 59.2 NA 92 NA
#> 4 4 0 0 72.1 NA 52 NA
#> 5 5 1 0 61.4 40 137 NA
#> 6 6 1 0 45.0 17 NA 114
#> 7 7 1 1 103.6 NA 275 NA
#> 8 8 0 1 98.5 202 397 NA
#> 9 9 0 0 77.1 NA NA 70
#> 10 10 0 0 53.7 NA 90 NA
#> time_censored
#> 1 NA
#> 2 NA
#> 3 NA
#> 4 NA
#> 5 NA
#> 6 NA
#> 7 NA
#> 8 NA
#> 9 NA
#> 10 NA
Quick summary of event counts:
# Patient total plus, for each time column, how many patients have a
# non-missing value (i.e. the event or censoring was recorded).
event_counts <- c(
  "Total patients:" = nrow(dat),
  "Heart failure observed:" = sum(!is.na(dat$time_HeartFailure)),
  "Cured:" = sum(!is.na(dat$time_Cured)),
  "Death:" = sum(!is.na(dat$time_Death)),
  "Censored (no absorbing state):" = sum(!is.na(dat$time_censored))
)
for (label in names(event_counts)) {
  cat(label, event_counts[[label]], "\n")
}
#> Total patients: 400
#> Heart failure observed: 192
#> Cured: 259
#> Death: 130
#> Censored (no absorbing state): 11
# Attach RFmstate before the multistate calls that follow
# (presumably the source of define_multistate() and the functions used below).
library(RFmstate)
# Describe the four-state model: Cured and Death are absorbing, Treatment can
# move to any of the other three states, and HeartFailure can only move on to
# one of the two absorbing states (five transitions in total).
ms <- define_multistate(
state_names = c("Treatment", "HeartFailure", "Cured", "Death"),
absorbing = c("Cured", "Death"),
transitions = list(
Treatment = c("HeartFailure", "Cured", "Death"),
HeartFailure = c("Cured", "Death")
)
)
print(ms)
#> Multistate Structure
#> States: Treatment -> HeartFailure -> Cured -> Death
#> Absorbing: Cured, Death
#> Transitions: 5
#> 1: Treatment -> HeartFailure
#> 2: Treatment -> Cured
#> 3: Treatment -> Death
#> 4: HeartFailure -> Cured
#> 5: HeartFailure -> Death
# Build the multistate dataset from the wide table `dat` and the structure
# `ms`. `time_map` names, for each non-initial state, the `dat` column that
# holds its time in days since treatment start; `censor_col` names the column
# with the censoring time; `covariates` lists the predictor columns to carry.
msdata <- prepare_data(
data = dat,
id = "record_id",
structure = ms,
time_map = list(
HeartFailure = "time_HeartFailure",
Cured = "time_Cured",
Death = "time_Death"
),
censor_col = "time_censored",
covariates = c("gender", "trt", "weight")
)
print(msdata)
#> Multistate Data (msdata)
#> Patients: 400
#> Intervals: 592
#> Transitions observed: 581
#> Censored intervals: 11
#> States: Treatment, HeartFailure, Cured, Death
#>
#> Transition counts:
#> to
#> from Cured Death HeartFailure
#> HeartFailure 120 61 0
#> Treatment 139 69 192
# Draw the state diagram for `ms`; `msdata` is passed as well, presumably to
# annotate the arrows with the observed event counts.
plot_transition_diagram(ms, msdata)
State transition diagram with event counts.
# Aalen-Johansen estimate computed from the prepared multistate data
# (no covariates are passed to this call).
aj <- aalen_johansen(msdata)
print(aj)
#> Aalen-Johansen Estimate
#> Time range: [1, 644]
#> Event times: 312
#> States: Treatment, HeartFailure, Cured, Death
#>
#> Event counts per transition:
#> from to n_events
#> Treatment HeartFailure 192
#> Treatment Cured 139
#> Treatment Death 69
#> HeartFailure Cured 120
#> HeartFailure Death 61
#>
#> Final state occupation probabilities:
#> Treatment: 0
#> HeartFailure: 0.011
#> Cured: 0.6624
#> Death: 0.3267
# Four views of the Aalen-Johansen estimate `aj`, selected via `type`.
plot(aj, type = "state_occupation")
State occupation probabilities (Aalen-Johansen).
plot(aj, type = "cumulative_hazard")
Nelson-Aalen cumulative hazards by transition.
plot(aj, type = "stacked_transition_prob")
Transition probabilities from Treatment state (AJ).
plot(aj, type = "transition_intensity")
Transition intensities over time (AJ).
# Fit the random forest multistate model on `msdata` using the three
# covariates. The forest settings (500 trees, minimum node size 15) and the
# seed are fixed explicitly so the fit can be reproduced.
fit <- rfmstate(
msdata,
covariates = c("gender", "trt", "weight"),
num.trees = 500,
min.node.size = 15,
seed = 2026
)
print(fit)
#> Random Forest Multistate Model
#> Call: rfmstate(msdata = msdata, covariates = c("gender", "trt", "weight"),
#> num.trees = 500, min.node.size = 15, seed = 2026)
#>
#> Covariates: gender, trt, weight
#> Parameters:
#> num.trees: 500
#> mtry: 1
#> min.node.size: 15
#>
#> Models fitted per origin state:
#> Treatment (n=400): -> HeartFailure, Cured, Death
#> HeartFailure (n=192): -> Cured, Death
# NOTE(review): `s` is not referenced again in the visible part of this file;
# confirm whether it is used further down before removing it.
s <- summary(fit)
# Per-transition feature importance for the fitted model.
imp <- importance(fit)
print(imp)
#> Feature Importance per Transition
#> ============================================================
#>
#> Treatment -> HeartFailure Treatment -> Cured Treatment -> Death
#> gender -0.0038 -0.0041 0.0133
#> trt 0.0304 0.0614 0.0165
#> weight 0.0292 -0.0036 0.0082
#> HeartFailure -> Cured HeartFailure -> Death
#> gender -0.0027 0.0022
#> trt 0.0496 0.0180
#> weight -0.0132 0.0206
#>
#> Top variables per transition:
#> Treatment -> HeartFailure: trt (0.0304)
#> Treatment -> Cured: trt (0.0614)
#> Treatment -> Death: trt (0.0165)
#> HeartFailure -> Cured: trt (0.0496)
#> HeartFailure -> Death: weight (0.0206)
# Two views of the importance table `imp`, selected via `type`.
plot(imp, type = "barplot")
Feature importance per transition.
plot(imp, type = "heatmap")
Feature importance heatmap.
# Three hypothetical patient profiles to predict for, one per row, labelled
# through the row names. Columns match the covariates used to fit the model.
new_patients <- data.frame(
  gender = c(1, 0, 1),
  trt = c(1, 0, 1),
  weight = c(65, 90, 75),
  row.names = c(
    "Light male, Drug A",
    "Heavy female, Drug B",
    "Average male, Drug A"
  )
)
print(new_patients)
#> gender trt weight
#> Light male, Drug A 1 1 65
#> Heavy female, Drug B 0 0 90
#> Average male, Drug A 1 1 75
# Predict for the three profiles on a 30-day grid: seq(30, 1000, by = 30)
# yields 33 time points, the last being 990.
pred <- predict(fit, newdata = new_patients, times = seq(30, 1000, by = 30))
print(pred)
#> RF Multistate Predictions
#> Subjects: 3
#> Time points: 33
#> Time range: [30, 990]
#> States: Treatment, HeartFailure, Cured, Death
# Per-patient prediction plots; `subject` picks the row of `new_patients`
# (1 = light male, 2 = heavy female, 3 = average male).
plot(pred, type = "state_occupation", subject = 1)
Predicted state occupation: light male on Drug A.
plot(pred, type = "state_occupation", subject = 2)
Predicted state occupation: heavy female on Drug B.
plot(pred, type = "state_occupation", subject = 3)
Predicted state occupation: average male on Drug A.
plot(pred, type = "transition_prob", subject = 1)
Predicted transition probabilities: light male on Drug A.
# Compute diagnostics for the fitted model and print them.
# NOTE(review): the variable name `diag` is also the name of a base R
# function; calls to diag() still resolve to the function.
diag <- diagnose(fit)
print(diag)
#> RF Multistate Model Diagnostics
#> ============================================================
#>
#> OOB Prediction Error:
#> ----------------------------------------
#> Treatment -> HeartFailure 0.4178
#> Treatment -> Cured 0.4319
#> Treatment -> Death 0.4828
#> HeartFailure -> Cured 0.4367
#> HeartFailure -> Death 0.4626
#>
#> Concordance Index (C-index):
#> ----------------------------------------
#> Treatment -> HeartFailure 0.6804
#> Treatment -> Cured 0.7098
#> Treatment -> Death 0.7603
#> HeartFailure -> Cured 0.6796
#> HeartFailure -> Death 0.7464
#>
#> Bias-Variance Decomposition:
#> ------------------------------------------------------------
#> Transition Bias Var MSE
#> ------------------------------------------------------------
#> Treatment -> HeartFailure 0.0391 0.0057 0.1982
#> Treatment -> Cured 0.0634 0.0063 0.1868
#> Treatment -> Death 0.0250 0.0022 0.1112
#> HeartFailure -> Cured 0.0384 0.0090 0.2103
#> HeartFailure -> Death 0.0220 0.0037 0.1482
# Three views of the diagnostics object `diag`, selected via `type`.
plot(diag, type = "brier")
Time-dependent Brier score by transition.
plot(diag, type = "concordance")
Concordance index by transition.
plot(diag, type = "bias_variance")
Bias-variance decomposition by transition.
We can compare predicted outcomes between Drug A and Drug B for a male patient of average weight (75), holding every covariate except treatment fixed.
# Compare predicted state occupation under Drug A and Drug B for the same
# covariate profile (gender = 1, weight = 75); only `trt` differs.
drug_a <- data.frame(gender = 1, trt = 1, weight = 75)
drug_b <- data.frame(gender = 1, trt = 0, weight = 75)
pred_a <- predict(fit, newdata = drug_a, times = seq(30, 1000, by = 30))
pred_b <- predict(fit, newdata = drug_b, times = seq(30, 1000, by = 30))
times <- pred_a$time
states <- ms$state_names
# par() returns the previous values of the settings it changes; keep them so
# the 2 x 2 layout and margins can be restored once the panels are drawn.
old_par <- par(mfrow = c(2, 2), mar = c(4, 4, 3, 1))
cols <- c("#1b9e77", "#d95f02")
# One panel per state, with both drugs overlaid.
for (j in seq_along(states)) {
  # state_occ is presumably indexed [subject, state, time]; each newdata has
  # a single row, hence the fixed first index.
  occ_a <- pred_a$state_occ[1, j, ]
  occ_b <- pred_b$state_occ[1, j, ]
  # Shared y-axis with 10% headroom so neither curve touches the top edge.
  plot(times, occ_a, type = "l", col = cols[1], lwd = 2,
       ylim = c(0, max(c(occ_a, occ_b)) * 1.1),
       xlab = "Days", ylab = "Probability",
       main = states[j])
  lines(times, occ_b, col = cols[2], lwd = 2)
  legend("topright", legend = c("Drug A", "Drug B"),
         col = cols, lwd = 2, bty = "n", cex = 0.8)
}
# Put the graphics settings back so later plots are not drawn in the 2 x 2 grid.
par(old_par)