# An example of fitting multiclass classification models with tidymodels,
# building a stacking ensemble with the stacks package, and using ensModelVis
# to visualise the fitted models, including a majority-vote ensemble.
library(ensModelVis)

# Packages we will need:
library(tidymodels)
library(stacks)
library(stringr)
library(discrim) # for LDA

# Dataset: iris data, split the training data, scale.
# Load iris, rename the outcome to `Response` and move it to the first column.
data(iris)
train <- iris %>%
  rename(Response = Species) %>%
  relocate(Response)

# 50/50 train/test split, stratified by class; seeded for reproducibility.
set.seed(1979)
tr <- initial_split(train, prop = 0.5, strata = Response)
train_data <- training(tr)
test_data <- testing(tr)

# Centre and scale every predictor with the TRAINING means and standard
# deviations; the same training statistics are applied to the test set.
# Predictors are selected by name rather than by position, and `stats::sd`
# is called explicitly because its result is stored in a variable that is
# also named `sd`.
predictors <- setdiff(names(train_data), "Response")
mn <- vapply(train_data[predictors], mean, numeric(1))
sd <- vapply(train_data[predictors], stats::sd, numeric(1))
train_data[predictors] <- sweep(train_data[predictors], 2, mn, "-")
train_data[predictors] <- sweep(train_data[predictors], 2, sd, "/")
test_data[predictors] <- sweep(test_data[predictors], 2, mn, "-")
test_data[predictors] <- sweep(test_data[predictors], 2, sd, "/")

# Set the recipe for stacks
# Recipe shared by every model: Response modelled on all remaining columns.
spec_rec <- recipe(Response ~ ., data = train_data)

# Base workflow holding only the recipe; each model below is added to it.
spec_wflow <-
  workflow() %>%
  add_recipe(spec_rec)

# Control objects from stacks, passed to tune_grid() / fit_resamples() below
# so the results can later be given to add_candidates().
ctrl_grid <- control_stack_grid()
ctrl_res <- control_stack_resamples()

# 10-fold cross-validation of the training data, stratified by class.
folds <- train_data %>%
  vfold_cv(v = 10, strata = Response)

# Try some classifiers:
# Multilayer perceptron on the nnet engine; hidden units, penalty and epochs
# are all left to be tuned.
nnet_mod <- mlp(hidden_units = tune(), penalty = tune(), epochs = tune()) %>%
  set_mode("classification") %>%
  set_engine("nnet")

# Attach the model to the shared recipe workflow.
nnet_wf <- add_model(spec_wflow, nnet_mod)

# Tune over the cross-validation folds with a grid of size 10.
nnet_res <- tune_grid(
  nnet_wf,
  resamples = folds,
  grid = 10,
  control = ctrl_grid
)
# ===================================
# Penalised multinomial regression on the glmnet engine with mixture fixed at
# 0.5; the penalty is tuned over 10 values log-spaced from 1e-8 to 1e-1.
lasso_reg_grid <- tibble(penalty = 10^seq(-8, -1, length.out = 10))

en_mod <- multinom_reg(penalty = tune(), mixture = 0.5) %>%
  set_engine("glmnet") %>%
  set_mode("classification")

# Attach the model to the shared recipe workflow.
en_wf <- add_model(spec_wflow, en_mod)

# Evaluate every penalty in the grid over the cross-validation folds.
en_res <- tune_grid(
  en_wf,
  resamples = folds,
  grid = lasso_reg_grid,
  control = ctrl_grid
)
# ===================================
# Linear discriminant analysis on the MASS engine. No parameters are marked
# for tuning, so the workflow is only resampled.
lda_mod <- discrim_linear() %>%
  set_engine("MASS") %>%
  set_mode("classification")

# Attach the model to the shared recipe workflow.
lda_wf <- add_model(spec_wflow, lda_mod)

lda_res <- lda_wf %>%
  fit_resamples(resamples = folds, control = ctrl_res)
# ==================================
# Random forest on the ranger engine with 500 trees; mtry is the floor of the
# square root of the number of columns of `train` minus one (the outcome).
rf_mod <-
  rand_forest(
    mtry = floor(sqrt(ncol(train) - 1)),
    trees = 500
  ) %>%
  set_engine("ranger") %>%
  set_mode("classification")

rf_wf <- spec_wflow %>%
  add_model(rf_mod)

rf_res <-
  rf_wf %>%
  fit_resamples(
    resamples = folds,
    control = ctrl_res
  )

# Build the stacking ensemble: collect the candidates from all four result
# objects, blend their predictions, then fit the members.
model_st <-
  stacks() %>%
  add_candidates(lda_res) %>%
  add_candidates(nnet_res) %>%
  add_candidates(rf_res) %>%
  add_candidates(en_res) %>%
  blend_predictions() %>%
  fit_members()
#> Warning: Predictions from 15 candidates were identical to those from existing
#> candidates and were removed from the data stack.

# Predict with new data: class and probability.
# Bind `select` to dplyr's version explicitly.
# NOTE(review): presumably guards against another attached package masking
# `select` (e.g. MASS) -- confirm.
select <- dplyr::select

# Predicted classes on the test set, with the true Response as first column.
# `members = TRUE` is requested so the result is not limited to the stack's
# own prediction.
ens_pred <-
  test_data %>%
  select(Response) %>%
  bind_cols(
    predict(
      model_st,
      test_data,
      type = "class",
      members = TRUE
    )
  )

# Same layout, but with class probabilities instead of predicted classes.
ens_prob <-
  test_data %>%
  select(Response) %>%
  bind_cols(
    predict(
      model_st,
      test_data,
      type = "prob",
      members = TRUE
    )
  )

# Rename to get nicer graphs.
# Shorten the prediction column names.
# The prefixes are matched literally with fixed(), so the leading "." is not
# treated as a regex wildcard. The trailing "_1" is anchored to the end of the
# name, so a name such as "nnet_10" keeps its index instead of being turned
# into "nnet0".
names(ens_pred) <- names(ens_pred) %>%
  str_remove(fixed(".pred_class_")) %>%
  str_remove(fixed("res_1_")) %>%
  str_remove("_1$")
# The remaining `.pred_class` column is labelled as the stack.
ens_pred <- ens_pred %>% rename(stack = .pred_class)

names(ens_prob) <- names(ens_prob) %>%
  str_remove(fixed(".pred_")) %>%
  str_remove(fixed("res_1_")) %>%
  str_remove("_1$")
# NOTE(review): columns 2-4 are presumably the stack's own per-class
# probabilities (they directly follow Response) -- confirm before reordering.
names(ens_prob)[2:4] <- str_c(names(ens_prob)[2:4], "_stack")

# Calculate AUC
# AUC per model.
# Columns are named "<class>_<model>". After pivoting to long form, `type`
# keeps the first three letters of the class ("set", "ver", "vir") and `name`
# keeps only the model, so pivoting wide again yields one probability column
# per class for each (observation, model) pair. roc_auc() is then computed per
# model on the columns set:vir.
auc <- ens_prob %>%
  mutate(id = row_number()) %>%
  pivot_longer(-c(Response, id)) %>%
  mutate(
    type = substr(name, 1, 3),
    name = str_remove(name, "^(setosa|versicolor|virginica)_")
  ) %>%
  pivot_wider(names_from = type, values_from = value) %>%
  group_by(name) %>%
  roc_auc(truth = Response, set:vir)

# Reshape to a single row with one AUC column per model.
auc <- auc %>%
  select(name, .estimate) %>%
  pivot_wider(names_from = name, values_from = .estimate)

# Only take the probability of the most likely class:
# For every observation and model keep only the largest class probability.
# The class prefix is stripped so the three class columns of a model share one
# `name`; the maximum is then taken within each (id, name) pair and the result
# is spread back to one column per model.
# `.groups = "drop"` returns an ungrouped result directly, so no separate
# ungroup() is needed and summarise() no longer prints its grouping message.
ens_prob <- ens_prob %>%
  mutate(id = row_number()) %>%
  pivot_longer(-c(Response, id)) %>%
  mutate(name = str_remove(name, "^(setosa|versicolor|virginica)_")) %>%
  group_by(id, name) %>%
  summarise(valuemax = max(value), .groups = "drop") %>%
  pivot_wider(
    id_cols = id,
    names_from = name,
    values_from = valuemax
  ) %>%
  select(-id)

# Make a plot:
# Plot the true classes against the predictions of every model.
plot_ensemble(
  ens_pred %>% pull(Response),
  ens_pred %>% select(-Response)
)

# Look at misclassified observations:
# Same plot with `incorrect = TRUE`.
plot_ensemble(
  ens_pred %>% pull(Response),
  ens_pred %>% select(-Response),
  incorrect = TRUE
)

# Add transparency by probability:
# Pass the per-model probability of the predicted class as `tibble_prob`.
plot_ensemble(
  ens_pred %>% pull(Response),
  ens_pred %>% select(-Response),
  tibble_prob = ens_prob
)

# Same, with `incorrect = TRUE`.
plot_ensemble(
  ens_pred %>% pull(Response),
  ens_pred %>% select(-Response),
  tibble_prob = ens_prob,
  incorrect = TRUE
)

# Reorder by AUC rather than accuracy:
# Put the AUC columns in the same order as the prediction columns of
# `ens_pred` (everything except Response) before passing them as `order`.
auc <- auc[, names(ens_pred %>% select(-Response))]

plot_ensemble(
  ens_pred %>% pull(Response),
  ens_pred %>% select(-Response),
  tibble_prob = ens_prob,
  order = auc
)

# Add majority vote:
# Majority vote per observation: the class predicted by the most columns of
# `ens_pred` (all prediction columns, Response excluded). which.max() returns
# the first maximum, so ties go to the class that comes first in table()'s
# ordering.
maj_vote <- apply(
  ens_pred %>% select(-Response),
  1,
  function(x) names(which.max(table(x)))
)

# Use the outcome's levels so the new column has the same level set as
# Response even when some class never wins the vote.
ens_pred <- ens_pred %>%
  mutate(maj_vote = factor(maj_vote, levels = levels(Response)))

plot_ensemble(
  ens_pred %>% pull(Response),
  ens_pred %>% select(-Response)
)

plot_ensemble(
  ens_pred %>% pull(Response),
  ens_pred %>% select(-Response),
  incorrect = TRUE
)

# Add probability of majority vote:
# Share of models that agree with the majority vote, per observation.
# The derived `maj_vote` column is excluded from the count: it always equals
# the winning class, so including it would report (k + 1) / (m + 1) instead of
# the actual vote share k / m. any_of() keeps this working if the column is
# absent.
prob_maj_vote <- apply(
  ens_pred %>% select(-Response, -any_of("maj_vote")),
  1,
  function(x) max(table(x)) / length(x)
)

ens_prob <- ens_prob %>%
  mutate(maj_vote = prob_maj_vote)

plot_ensemble(
  ens_pred %>% pull(Response),
  ens_pred %>% select(-Response),
  tibble_prob = ens_prob
)

plot_ensemble(
  ens_pred %>% pull(Response),
  ens_pred %>% select(-Response),
  tibble_prob = ens_prob,
  incorrect = TRUE
)