## ----setup, include = FALSE---------------------------------------------------
knitr::opts_chunk$set(
  collapse = TRUE, comment = "#>", # merge source/output; prefix output lines
  fig.width = 7, fig.height = 5    # default figure size for all chunks
)

## ----load-packages------------------------------------------------------------
library(nmfkc)
library(vars) # For Canada dataset

## ----air-data-prep------------------------------------------------------------
# AirPassengers (datasets package) is a ts object; log10() preserves the
# ts class, so d_air_log is still a time series on the log10 scale.
d_air <- AirPassengers
d_air_log <- log10(d_air)

## ----air-degree-cv------------------------------------------------------------
# Cross-validate lag orders D = 1, ..., 14 for a rank-1 (univariate) model.
# The ts input is passed as-is: it is transposed to (Variables x Time)
# inside the function.
cv_res <- nmfkc.ar.degree.cv(
  d_air_log,
  rank = 1,
  degree = 1:14,
  epsilon = 1e-6,
  maxit = 5e5
)

# Lag order chosen by cross-validation
cv_res$degree

# The remainder of the example fixes D = 12 (one year of monthly lags,
# to capture the seasonal pattern)
D <- 12

## ----air-model-fit------------------------------------------------------------
# Build the target (a_air$Y) and lagged-covariate (a_air$A) matrices for
# an AR(D) model with intercept; D = 12 here
a_air <- nmfkc.ar(d_air_log, degree = D, intercept = TRUE)

# Fit the NMF-AR model; rank 1 because the series is univariate
res_air <- nmfkc(
  Y = a_air$Y, A = a_air$A, rank = 1,
  epsilon = 1e-6, maxit = 5e5
)

# Goodness of fit
res_air$r.squared

# Stationarity check (spectral radius < 1)
nmfkc.ar.stationarity(res_air)

## ----air-forecast-------------------------------------------------------------
# Forecast the next 2 years (h = 24 monthly steps ahead)
h <- 24
pred_res <- nmfkc.ar.predict(x = res_air, Y = a_air$Y, n.ahead = h)

# The model was fit on log10 data, so back-transform to passenger counts
pred_val <- 10^as.vector(pred_res$pred)
pred_time <- pred_res$time # Future time points generated by the function

# In-sample fitted values on the original scale. The column names of
# res_air$XB are parsed as numeric time points (they are expected to carry
# the time strings of a_air$Y -- NOTE(review): confirm against nmfkc output).
fitted_time <- as.numeric(colnames(res_air$XB))
fitted_val <- 10^as.vector(res_air$XB)

# --- Plotting ---
# The y-range also covers the fitted values so the red line is never
# clipped where the fit overshoots the observed series.
xlim_range <- range(c(time(d_air), pred_time))
ylim_range <- range(c(d_air, fitted_val, pred_val))

# 1. Observed data (Black)
# The model is a univariate AR(D), so the title says NMF-AR; the horizon
# in the title follows h instead of being hard-coded.
plot(d_air, type = "l", col = "black",
     xlim = xlim_range, ylim = ylim_range, lwd = 1,
     xlab = "Year", ylab = "Air Passengers",
     main = paste0("NMF-AR Forecast (h=", h, ")"))

# 2. Fitted values during training (Red)
lines(fitted_time, fitted_val, col = "red", lwd = 2)

# 3. Forecast (Blue, dashed)
# Prepend the last observed point so the forecast line joins the data
last_t <- tail(as.numeric(time(d_air)), 1)
last_y <- tail(as.vector(d_air), 1)
lines(c(last_t, pred_time), c(last_y, pred_val), col = "blue", lwd = 2, lty = 2)

# Legend line types/widths mirror the plotted lines (observed uses lwd = 1)
legend("topleft", legend = c("Observed", "Fitted", "Forecast"),
       col = c("black", "red", "blue"), lty = c(1, 1, 2), lwd = c(1, 2, 2))

## ----canada-data-prep---------------------------------------------------------
# Canada (provided by vars): first-difference every column, then normalize
d0_canada <- Canada
dd_canada <- apply(d0_canada, MARGIN = 2, FUN = diff) # matrix, (Time - 1) x Vars
dn_canada <- nmfkc.normalize(dd_canada)

# NMF works on (Variables x Time), so transpose
Y0_canada <- t(dn_canada)

# Target and lagged-covariate matrices for a VAR(1) with intercept
a_canada <- nmfkc.ar(Y0_canada, degree = 1, intercept = TRUE)

## ----canada-model-fit---------------------------------------------------------
# Fit the NMF-VAR model with two bases; prefix = "Condition" is used for
# the basis labels (presumably Condition1, Condition2 -- confirm in nmfkc)
res_canada <- nmfkc(
  Y = a_canada$Y, A = a_canada$A,
  rank = 2, prefix = "Condition", epsilon = 1e-6
)

# Goodness of fit, then stationarity of the fitted VAR(1)
res_canada$r.squared
nmfkc.ar.stationarity(res_canada)

## ----canada-soft-cluster------------------------------------------------------
# Visualize soft clustering of time trends: each column of B.prob is drawn
# as one stacked bar, with one fill colour per basis (economic condition).
# The palette is derived once from the number of bases, so the barplot and
# the legend always agree; for rank 2 this is c(2, 3) as before.
cond_cols <- seq_len(ncol(res_canada$X)) + 1
barplot(res_canada$B.prob, col = cond_cols, border = NA,
        main = "Soft Clustering of Economic Conditions",
        xlab = "Time", ylab = "Probability",
        names.arg = colnames(a_canada$Y), las = 2, cex.names = 0.5)
legend("topright", legend = colnames(res_canada$X), fill = cond_cols,
       bg = "white")

## ----canada-dot---------------------------------------------------------------
# Build a Graphviz DOT description of the fitted VAR model (with intercept);
# threshold = 0.01 presumably hides weaker edges -- confirm in nmfkc docs
dot_script <- nmfkc.ar.DOT(res_canada, intercept = TRUE, threshold = 0.01)

# Rendering requires the DiagrammeR package:
# plot(dot_script)

