- ranger::treeInfo() to parse each decision path.
- wt ~ mpg + am
- mutate(mtcars, newam = paste0(am)) and then wt ~ mpg + newam
- wt ~ mpg + as.factor(am)
- wt ~ mpg + as.character(am)
- tidypredict_interval() & tidypredict_sql_interval()

Here is a simple ranger() model using the iris dataset:
library(ranger)
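# Random forests are stochastic: results vary between runs unless a seed is set with set.seed()
model <- ranger::ranger(Species ~ ., data = iris, num.trees = 100)
The SQL translation returns a single SQL CASE WHEN statement. Each decision path becomes one WHEN clause.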
library(tidypredict)
library(dplyr)  # provides %>%, mutate() and filter(), used below
tidypredict_sql(model, dbplyr::simulate_mssql())
## <SQL> CASE
## WHEN ((((`Petal.Length`) < 2.45) AND ((`Petal.Width`) < 1.75))) THEN ('setosa')
## WHEN ((((`Petal.Width`) >= 1.75) AND ((`Petal.Length`) < 4.85))) THEN ('virginica')
## WHEN ((((`Petal.Length`) >= 4.85) AND ((`Petal.Width`) >= 1.75))) THEN ('virginica')
## WHEN (((((`Petal.Length`) >= 2.45) AND ((`Petal.Length`) < 5.4)) AND ((`Petal.Width`) < 1.75))) THEN ('versicolor')
## WHEN (((((`Petal.Length`) >= 5.4) AND ((`Petal.Length`) >= 2.45)) AND ((`Petal.Width`) < 1.75))) THEN ('virginica')
## END
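To make the "one WHEN per path" structure concrete, here is a minimal base-R sketch that assembles a CASE WHEN string from a list of paths. The helpers path_to_when() and case_when_sql() are hypothetical and are not part of tidypredict; only the first two paths from the output above are used.

```r
# Hypothetical helpers (not tidypredict's API): build a CASE WHEN string.
# Each path is a vector of conditions plus the label it predicts.
paths <- list(
  list(conds = c("`Petal.Length` < 2.45", "`Petal.Width` < 1.75"), pred = "setosa"),
  list(conds = c("`Petal.Width` >= 1.75", "`Petal.Length` < 4.85"), pred = "virginica")
)

path_to_when <- function(p) {
  # AND together every step of the path, then map it to its label
  sprintf("WHEN (%s) THEN ('%s')", paste(p$conds, collapse = " AND "), p$pred)
}

case_when_sql <- function(paths) {
  whens <- vapply(paths, path_to_when, character(1))
  # Wrap all WHEN clauses in a single CASE ... END block
  paste(c("CASE", whens, "END"), collapse = "\n")
}

cat(case_when_sql(paths), "\n")
```

Because the paths of a single tree do not overlap, the order of the WHEN clauses does not change the result here.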
Alternatively, use tidypredict_to_column() if the results are to be used or previewed in dplyr.
iris %>%
tidypredict_to_column(model) %>%
head(10)
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species fit
## 1 5.1 3.5 1.4 0.2 setosa setosa
## 2 4.9 3.0 1.4 0.2 setosa setosa
## 3 4.7 3.2 1.3 0.2 setosa setosa
## 4 4.6 3.1 1.5 0.2 setosa setosa
## 5 5.0 3.6 1.4 0.2 setosa setosa
## 6 5.4 3.9 1.7 0.4 setosa setosa
## 7 4.6 3.4 1.4 0.3 setosa setosa
## 8 5.0 3.4 1.5 0.2 setosa setosa
## 9 4.4 2.9 1.4 0.2 setosa setosa
## 10 4.9 3.1 1.5 0.1 setosa setosa
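The fit column is simply the five decision paths evaluated row by row. As a rough base-R illustration, assuming only the thresholds printed in the SQL above, the hypothetical predict_paths() below hand-codes those paths and adds the result as a fit column with transform():

```r
# Hypothetical hand-coded version of the five paths printed by tidypredict_sql().
# Paths are checked in order and the first matching one wins, as in CASE WHEN.
predict_paths <- function(df) {
  pl <- df$Petal.Length
  pw <- df$Petal.Width
  out <- rep(NA_character_, nrow(df))
  rules <- list(
    list(pl < 2.45 & pw < 1.75, "setosa"),
    list(pw >= 1.75 & pl < 4.85, "virginica"),
    list(pl >= 4.85 & pw >= 1.75, "virginica"),
    list(pl >= 2.45 & pl < 5.4 & pw < 1.75, "versicolor"),
    list(pl >= 5.4 & pl >= 2.45 & pw < 1.75, "virginica")
  )
  for (r in rules) {
    # Only fill rows that no earlier path has matched yet
    hit <- is.na(out) & r[[1]]
    out[hit] <- r[[2]]
  }
  out
}

iris_fit <- transform(iris, fit = predict_paths(iris))
head(iris_fit, 3)
```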
The parser is based on the output from the ranger::treeInfo() function. It returns one decision path per terminal node, that is, one for each non-NA row in the prediction column.
treeInfo(model)
## nodeID leftChild rightChild splitvarID splitvarName splitval terminal
## 1 0 1 2 4 Petal.Width 1.75 FALSE
## 2 1 3 4 3 Petal.Length 2.45 FALSE
## 3 2 5 6 3 Petal.Length 4.85 FALSE
## 4 3 NA NA NA <NA> NA TRUE
## 5 4 7 8 3 Petal.Length 5.40 FALSE
## 6 5 NA NA NA <NA> NA TRUE
## 7 6 NA NA NA <NA> NA TRUE
## 8 7 NA NA NA <NA> NA TRUE
## 9 8 NA NA NA <NA> NA TRUE
## prediction
## 1 <NA>
## 2 <NA>
## 3 <NA>
## 4 setosa
## 5 <NA>
## 6 virginica
## 7 virginica
## 8 versicolor
## 9 virginica
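One way to picture how a treeInfo() table turns into decision paths is a depth-first walk from the root to every terminal node. The sketch below copies the table printed above into a data frame and uses a hypothetical walk_paths() helper; it follows the document's SQL in writing left branches as `<` and right branches as `>=`, and is not tidypredict's actual parser.

```r
# treeInfo()-style table; values copied from the output above
tree <- data.frame(
  nodeID       = 0:8,
  leftChild    = c(1, 3, 5, NA, 7, NA, NA, NA, NA),
  rightChild   = c(2, 4, 6, NA, 8, NA, NA, NA, NA),
  splitvarName = c("Petal.Width", "Petal.Length", "Petal.Length", NA,
                   "Petal.Length", NA, NA, NA, NA),
  splitval     = c(1.75, 2.45, 4.85, NA, 5.40, NA, NA, NA, NA),
  terminal     = c(FALSE, FALSE, FALSE, TRUE, FALSE, TRUE, TRUE, TRUE, TRUE),
  prediction   = c(NA, NA, NA, "setosa", NA, "virginica", "virginica",
                   "versicolor", "virginica"),
  stringsAsFactors = FALSE
)

# Hypothetical helper: collect one condition per split on the way down
walk_paths <- function(tree, id = 0, conds = character()) {
  row <- tree[tree$nodeID == id, ]
  if (row$terminal) {
    # A terminal node closes one complete decision path
    return(list(list(conds = conds, pred = row$prediction)))
  }
  left  <- sprintf("%s < %s",  row$splitvarName, row$splitval)
  right <- sprintf("%s >= %s", row$splitvarName, row$splitval)
  c(walk_paths(tree, row$leftChild,  c(conds, left)),
    walk_paths(tree, row$rightChild, c(conds, right)))
}

paths <- walk_paths(tree)
length(paths)  # one path per non-NA prediction
```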
The parsed model contains one row for each path. The field, operator and split_point columns list every step of the path, concatenated into a single character string with {:} as the separator.
parse_model(model)
## # A tibble: 6 x 7
## labels vals type estimate field operator split_p…
## <chr> <fctr> <chr> <dbl> <chr> <chr> <chr>
## 1 path-1 setosa path 0 Petal.Length{:}… left{:}le… 2.45{:}…
## 2 path-2 virginica path 0 Petal.Length{:}… left{:}ri… 4.85{:}…
## 3 path-3 virginica path 0 Petal.Length{:}… right{:}r… 4.85{:}…
## 4 path-4 versicolor path 0 Petal.Length{:}… left{:}ri… 5.4{:}2…
## 5 path-5 virginica path 0 Petal.Length{:}… right{:}r… 5.4{:}2…
## 6 model ranger variable NA <NA> <NA> <NA>
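To recover the individual steps, each concatenated column can be split on the {:} separator. The full strings are truncated in the tibble above, so the values below are assumed, modeled on path-1 and its SQL translation; split_steps() is a hypothetical helper.

```r
# Assumed values shaped like the path-1 row (the printed strings are truncated)
field       <- "Petal.Length{:}Petal.Width"
operator    <- "left{:}left"
split_point <- "2.45{:}1.75"

# Hypothetical helper: split one concatenated string into its steps
split_steps <- function(x) strsplit(x, "{:}", fixed = TRUE)[[1]]

steps <- data.frame(
  field       = split_steps(field),
  operator    = split_steps(operator),
  split_point = as.numeric(split_steps(split_point)),
  stringsAsFactors = FALSE
)

# Translate each step into a comparison ("left" -> "<", "right" -> ">=")
steps$condition <- sprintf(
  "%s %s %s", steps$field,
  ifelse(steps$operator == "left", "<", ">="), steps$split_point
)
paste(steps$condition, collapse = " & ")
```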
The output from parse_model() is transformed into a dplyr (a.k.a. Tidy Eval) formula. The entire decision tree becomes one dplyr::case_when() statement:
tidypredict_fit(model)
## case_when((((Petal.Length) < 2.45) & ((Petal.Width) < 1.75)) ~
## "setosa", (((Petal.Width) >= 1.75) & ((Petal.Length) < 4.85)) ~
## "virginica", (((Petal.Length) >= 4.85) & ((Petal.Width) >=
## 1.75)) ~ "virginica", ((((Petal.Length) >= 2.45) & ((Petal.Length) <
## 5.4)) & ((Petal.Width) < 1.75)) ~ "versicolor", ((((Petal.Length) >=
## 5.4) & ((Petal.Length) >= 2.45)) & ((Petal.Width) < 1.75)) ~
## "virginica")
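Because the whole tree is a single expression, it can be evaluated against any data that supplies the referenced columns. As a base-R sketch of the same idea, the hypothetical build_expr() below folds the five paths into one nested ifelse() call; nested ifelse() is a swapped-in construct standing in for case_when(), not what tidypredict generates.

```r
# Conditions and outcomes copied from the case_when() output above
conds <- c(
  "Petal.Length < 2.45 & Petal.Width < 1.75",
  "Petal.Width >= 1.75 & Petal.Length < 4.85",
  "Petal.Length >= 4.85 & Petal.Width >= 1.75",
  "Petal.Length >= 2.45 & Petal.Length < 5.4 & Petal.Width < 1.75",
  "Petal.Length >= 5.4 & Petal.Length >= 2.45 & Petal.Width < 1.75"
)
outcomes <- c("setosa", "virginica", "virginica", "versicolor", "virginica")

# Hypothetical helper: fold the paths, last to first, into one expression.
# Rows matching no path fall through to NA, like case_when()'s default.
build_expr <- function(conds, outcomes) {
  expr <- NA_character_
  for (i in rev(seq_along(conds))) {
    expr <- call("ifelse", str2lang(conds[i]), outcomes[i], expr)
  }
  expr
}

fit_expr <- build_expr(conds, outcomes)
fit <- eval(fit_expr, iris)  # column names are looked up in iris
table(fit)
```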
From there, the Tidy Eval formula can be used anywhere it can be evaluated. tidypredict provides three paths:
- dplyr: mutate(iris, !! tidypredict_fit(model))
- tidypredict_to_column(model) in a piped command set
- tidypredict_sql(model, con) to retrieve the SQL statement

Currently, the formula matches 146 out of 150 predictions of the test model. The threshold argument of tidypredict_test() is an integer: the number of records that are allowed to differ from the baseline predictions returned by predict().
test <- tidypredict_test(model, iris, threshold = 4)
test
## tidypredict test results
##
## Success, test is under the set threshold of: 4
## Predictions that did not match predict(): 4
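The check amounts to counting disagreements and comparing the count with the threshold. A minimal sketch with made-up vectors follows; check_threshold() is hypothetical, and it assumes that a mismatch count equal to the threshold still passes, which is consistent with the result above (4 mismatches, threshold 4).

```r
# Hypothetical re-creation of the threshold check (not tidypredict's code)
check_threshold <- function(baseline, formula_pred, threshold) {
  mismatches <- sum(baseline != formula_pred)
  # Assumption: passing means mismatches do not exceed the threshold
  list(mismatches = mismatches, success = mismatches <= threshold)
}

# Made-up predictions: positions 2 and 4 disagree
baseline     <- c("setosa", "versicolor", "virginica", "virginica")
formula_pred <- c("setosa", "virginica",  "virginica", "versicolor")
check_threshold(baseline, formula_pred, threshold = 2)
```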
test$raw_results %>%
filter(predict != tidypredict)
## # A tibble: 4 x 7
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species pred… tidy…
## <dbl> <dbl> <dbl> <dbl> <fctr> <chr> <chr>
## 1 5.90 3.20 4.80 1.80 versicolor vers… virg…
## 2 4.90 2.50 4.50 1.70 virginica virg… vers…
## 3 6.00 2.20 5.00 1.50 virginica virg… vers…
## 4 6.30 2.80 5.10 1.50 virginica virg… vers…
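One likely source of these differences: the paths above match a single tree (treeInfo() describes tree 1 unless another tree is requested), while predict() on a ranger classification forest aggregates the votes of all 100 trees. Assuming a plain majority vote, the sketch below uses made-up votes from five trees to show how the first tree can disagree with the forest on rows near a split boundary.

```r
# Made-up votes from five trees, one row per observation
votes <- rbind(
  c("virginica",  "versicolor", "versicolor", "versicolor", "virginica"),
  c("versicolor", "virginica",  "virginica",  "versicolor", "virginica"),
  c("setosa",     "setosa",     "setosa",     "setosa",     "setosa")
)

# Majority vote per row; in this sketch ties go to the alphabetically first label
majority   <- unname(apply(votes, 1, function(v) names(which.max(table(v)))))
first_tree <- votes[, 1]

data.frame(first_tree, majority, same = first_tree == majority)
```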