---
title: "tidymodels / parsnip Integration"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{tidymodels / parsnip Integration}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment  = "#>",
  eval     = requireNamespace("parsnip", quietly = TRUE)
)
```

ggmlR registers a `"ggml"` engine for `parsnip::mlp()`, giving you
GPU-accelerated neural networks inside the tidymodels ecosystem —
resampling, tuning, workflows, and recipes all work out of the box.

```{r}
library(ggmlR)
library(parsnip)
```

---

## 1. Classification

```{r}
spec <- mlp(
  hidden_units = c(64L, 32L),
  epochs       = 20L,
  dropout      = 0.1
) |>
  set_engine("ggml") |>
  set_mode("classification")

fit_obj <- fit(spec, Species ~ ., data = iris)

# Class predictions
preds <- predict(fit_obj, new_data = iris)
head(preds)

# Probability predictions
probs <- predict(fit_obj, new_data = iris, type = "prob")
head(probs)

# Accuracy
cat(sprintf("Accuracy: %.4f\n", mean(preds$.pred_class == iris$Species)))
```

---

## 2. Regression

```{r}
spec_reg <- mlp(
  hidden_units = c(64L, 32L),
  epochs       = 50L
) |>
  set_engine("ggml") |>
  set_mode("regression")

fit_reg <- fit(spec_reg, mpg ~ ., data = mtcars)

preds_reg <- predict(fit_reg, new_data = mtcars)
head(preds_reg)
```

---

## 3. Engine parameters

The `ggml` engine maps standard parsnip arguments to ggmlR internals:

| parsnip | ggmlR | Default |
|---------|-------|---------|
| `hidden_units` | `hidden_layers` | `c(128, 64)` |
| `epochs` | `epochs` | `10` |
| `dropout` | `dropout` | `0.2` |
| `activation` | `activation` | `"relu"` |
| `learn_rate` | `learn_rate` | `0.001` |

```{r}
# Customize architecture
spec_custom <- mlp(
  hidden_units = c(128L, 64L, 32L),
  epochs       = 30L,
  dropout      = 0.3,
  activation   = "relu"
) |>
  set_engine("ggml") |>
  set_mode("classification")
```

---

## 4. Resampling with rsample

```{r, eval=FALSE}
library(rsample)

folds <- vfold_cv(iris, v = 5L)

spec <- mlp(hidden_units = c(32L), epochs = 10L) |>
  set_engine("ggml") |>
  set_mode("classification")

library(tune)
library(yardstick)
library(workflows)

wf <- workflow() |>
  add_model(spec) |>
  add_formula(Species ~ .)

results <- fit_resamples(wf, resamples = folds)
collect_metrics(results)
```

---

## 5. Recipes for preprocessing

ggmlR accepts only numeric features. Use `recipes` to handle factors,
missing values, and scaling:

```{r, eval=FALSE}
library(recipes)
library(workflows)

rec <- recipe(Species ~ ., data = iris) |>
  step_normalize(all_numeric_predictors())

spec <- mlp(hidden_units = c(32L), epochs = 10L) |>
  set_engine("ggml") |>
  set_mode("classification")

wf <- workflow() |>
  add_recipe(rec) |>
  add_model(spec)

fit_obj <- fit(wf, data = iris)
predict(fit_obj, new_data = iris)
```

For datasets with factors:

```{r, eval=FALSE}
rec <- recipe(Species ~ ., data = iris) |>
  step_dummy(all_nominal_predictors()) |>
  step_normalize(all_numeric_predictors())
```

---

## 6. Hyperparameter tuning

```{r, eval=FALSE}
library(tune)
library(dials)
library(workflows)

spec <- mlp(
  hidden_units = tune(),
  epochs       = tune(),
  dropout      = tune()
) |>
  set_engine("ggml") |>
  set_mode("classification")

wf <- workflow() |>
  add_model(spec) |>
  add_formula(Species ~ .)

grid <- grid_regular(
  hidden_units(range = c(16L, 128L)),
  epochs(range = c(10L, 50L)),
  dropout(range = c(0, 0.4)),
  levels = 3L
)

folds <- vfold_cv(iris, v = 3L)
results <- tune_grid(wf, resamples = folds, grid = grid)
show_best(results, metric = "accuracy")
```

---

## 7. Comparison with other engines

```{r, eval=FALSE}
library(workflows)
library(workflowsets)

specs <- workflow_set(
  preproc = list(basic = Species ~ .),
  models  = list(
    ggml  = mlp(hidden_units = c(32L), epochs = 20L) |> set_engine("ggml"),
    nnet  = mlp(hidden_units = 32L,    epochs = 200L) |> set_engine("nnet")
  )
) |>
  workflow_map("fit_resamples",
               resamples = vfold_cv(iris, v = 5L))

rank_results(specs, rank_metric = "accuracy")
```

---

## Summary

| Feature | Supported |
|---------|-----------|
| Classification | Yes (`class`, `prob`) |
| Regression | Yes (`numeric`) |
| GPU (Vulkan) | Yes (auto-detected) |
| Recipes / preprocessing | Yes |
| Resampling | Yes |
| Tuning | Yes |
| Workflows | Yes |
