---
title: "variable selection: non-linear data"
output: rmarkdown::html_vignette
params:
  eval: true
vignette: >
  %\VignetteIndexEntry{variable selection: non-linear data}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
# Attach the package that this vignette demonstrates.
library(LBBNN)

# Render chunk output directly beneath its source, prefixed with "#>".
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

# The chunks below are evaluated only when the torch package is present AND
# torch::torch_is_installed() reports TRUE; `&&` short-circuits, so torch is
# never touched when its namespace cannot be loaded.
has_torch <- requireNamespace("torch", quietly = TRUE) &&
  torch::torch_is_installed()
```

## Generate data

We again generate 1000 samples with 15 features, 6 of which are relevant for the outcome. This problem is more complicated
than the previous one, as it includes some non-linear effects. This time we also transform the outcome into a binary variable by splitting it at its median.
```{r, eval = has_torch}
# Number of samples (i) and number of candidate features (j).
i <- 1000
j <- 15

# Seed both R's RNG and torch's RNG so the vignette is reproducible.
set.seed(42)
torch::torch_manual_seed(42)

# Features are drawn uniformly from [0, 0.5].
X_nl <- matrix(runif(i * j, min = 0, max = 0.5), ncol = j)

# Continuous latent outcome: only columns 1-6 enter, through a log, a cosine,
# an interaction, a linear term and a squared term, plus Gaussian noise.
y_nl <- -3 +
  0.1 * log(abs(X_nl[, 1])) +
  3 * cos(X_nl[, 2]) +
  2 * X_nl[, 3] * X_nl[, 4] +
  X_nl[, 5] -
  X_nl[, 6]^2 +
  rnorm(i, sd = 0.1)

# Binarise at the median: 1 when strictly above it, 0 otherwise.
y <- as.numeric(y_nl > median(y_nl))

# Features first, response `y` as the last column.
sim_data_nl <- cbind(as.data.frame(X_nl), y)

# 90% of the rows go to training; the features are left unstandardised.
loaders_nl <- get_dataloaders(
  sim_data_nl,
  train_proportion = 0.9,
  train_batch_size = 450,
  test_batch_size = 100,
  standardize = FALSE
)
train_loader_nl <- loaders_nl$train_loader
test_loader_nl <- loaders_nl$test_loader
```

## Define hyperparameters and the model object

We use the same architecture as in the example with linear data. For this example, 
we use normalizing flows in the variational distribution. 
```{r, eval = has_torch}
# Task type handed to lbbnn_net().
problem <- "binary classification"

# j inputs, then 5, 5 and 1 -- presumably the widths of the two hidden layers
# and of the output layer (confirm against the lbbnn_net() documentation).
sizes <- c(j, 5, 5, 1)

# One value per entry after the input size: 0.5 for `prior`, 1 for `std`.
incl_priors <- rep(0.5, 3)
stds <- rep(1, 3)

incl_inits <- "polarized"
device <- "cpu"

# flow = TRUE turns on the normalizing flows mentioned in the text above;
# `dims` is presumably the flow network's hidden sizes -- TODO confirm.
model_nl <- lbbnn_net(
  problem_type = problem,
  sizes = sizes,
  prior = incl_priors,
  inclusion_inits = incl_inits,
  input_skip = TRUE,
  std = stds,
  flow = TRUE,
  dims = c(10, 10, 10),
  device = device,
  bias_inclusion_prob = FALSE
)
```

## Train and validate the model

```{r, eval = has_torch}
# Fit the model on the training loader: 20 epochs, learning rate 0.2, with
# progress output switched off.
train_lbbnn(
  epochs = 20,
  LBBNN = model_nl,
  lr = 0.2,
  train_dl = train_loader_nl,
  device = device,
  verbose = FALSE
)

# Evaluate on the held-out loader; num_samples = 2 is presumably the number of
# posterior draws used per prediction -- TODO confirm.
validate_lbbnn(
  LBBNN = model_nl,
  num_samples = 2,
  test_dl = test_loader_nl,
  device = device
)
```

## Check the global explanations

```{r, fig.width = 6, fig.height = 6, eval = has_torch}
# Draw the global explanation plot for the fitted model; the remaining
# arguments only tune the size of vertices, edges and labels.
plot(
  model_nl,
  type = "global",
  vertex_size = 7,
  edge_width = 0.4,
  label_size = 0.4
)
```

All the relevant features are included. 

