Introduction to PIE – A Partially Interpretable Model with Black-box Refinement

Tong Wang, Jingyi Yang, Yunyi Li and Boxiang Wang


Introduction to PIE

The PIE package implements Partially Interpretable Estimators (PIE), a framework that jointly trains an interpretable model and a black-box model to achieve high predictive performance together with partial model transparency.


To install the package from CRAN, run the following:

# Install the R package from CRAN
install.packages("PIE")

Getting Started

This section demonstrates how to prepare a dataset, fit a PIE model, generate predictions, and evaluate the results.

Generate Data

The function data_process() converts a dataset into the format expected by the PIE model. Its output includes the cross-validation splits (training, validation, and test sets) and the group indicators used by the group lasso.

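# Load the package and the training data
# (assumes the winequality dataset ships with the PIE package)
library(PIE)
data("winequality")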
# Which columns are numerical?
num_col <- 1:11
# Which columns are categorical?
cat_col <- 12
# Which column is the response?
y_col <- ncol(winequality)
# Data Processing
dat <- data_process(X = as.matrix(winequality[, -y_col]), 
  y = winequality[, y_col], 
  num_col = num_col, cat_col = cat_col, y_col = y_col)
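To make the idea of "group indicators for group lasso" concrete, here is a minimal base-R sketch on hypothetical toy data. It is not the package's implementation: the use of B-splines (suggested only by the spl_ prefix in the output names), the choice df = 4, and every variable name below are assumptions made for illustration. The point it shows is that all columns derived from one original feature share a single group id, so the group lasso can keep or drop the feature as a whole.

```r
library(splines)  # ships with base R

set.seed(1)
# Hypothetical toy data: two numeric features and one binary categorical feature
n  <- 100
x1 <- rnorm(n)
x2 <- runif(n)
g  <- factor(sample(c("a", "b"), n, replace = TRUE))

# Expand each numeric feature into a B-spline basis (df = 4 is an arbitrary choice)
b1 <- bs(x1, df = 4)
b2 <- bs(x2, df = 4)

# One-hot encode the categorical feature, dropping the intercept column
d <- model.matrix(~ g)[, -1, drop = FALSE]

# Expanded design matrix: 4 + 4 + 1 columns
X_expanded <- cbind(b1, b2, d)

# Group indicator: columns that come from the same original feature share one id
group_id <- c(rep(1, ncol(b1)), rep(2, ncol(b2)), rep(3, ncol(d)))

stopifnot(length(group_id) == ncol(X_expanded))
```

In the workflow above, the corresponding objects would be taken from the list returned by data_process() rather than built by hand.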

Fitting PIE

Once the data is prepared, you can use the PIE_fit() function to train a PIE model. In this example, we run only 5 iterations, fitting group lasso and XGBoost models.

# Fit a PIE model
fold <- 1
fit <- PIE_fit(
  X = dat$spl_train_X[[fold]],
  y = dat$train_y[[fold]],
  lasso_group = dat$lasso_group,
  X_orig = dat$orig_train_X[[fold]],
  lambda1 = 0.01, lambda2 = 0.01, iter = 5, eta = 0.05, nrounds = 200
)
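As a rough intuition for what "jointly training" over several iterations can look like, the sketch below alternates between an interpretable fit and a flexible fit on toy data, each one fitting what the other has left unexplained. This is an assumption about the general scheme, not the algorithm inside PIE_fit(): an ordinary linear model (lm) stands in for the group lasso, loess stands in for XGBoost, and all names and data are hypothetical.

```r
set.seed(42)
# Hypothetical toy data: a linear effect in x1 plus a nonlinear effect in x2
n   <- 200
toy <- data.frame(x1 = runif(n, -1, 1), x2 = runif(n, -1, 1))
toy$y <- 2 * toy$x1 + sin(3 * toy$x2) + rnorm(n, sd = 0.1)

iter  <- 5          # number of alternating rounds, as in the example above
white <- rep(0, n)  # contribution of the interpretable stand-in
black <- rep(0, n)  # contribution of the black-box stand-in

for (k in seq_len(iter)) {
  # Interpretable part: fit what the black-box part has not explained
  toy$r <- toy$y - black
  white <- fitted(lm(r ~ x1 + x2, data = toy))

  # Black-box part: refine the residual left by the interpretable part
  toy$r <- toy$y - white
  black <- fitted(loess(r ~ x1 + x2, data = toy))
}

# The combined prediction is the sum of the two parts
total <- white + black

# In-sample comparison against the linear model alone
mse_linear_only <- mean((toy$y - fitted(lm(y ~ x1 + x2, data = toy)))^2)
mse_total       <- mean((toy$y - total)^2)
```

On this toy problem the combined fit has a lower in-sample error than the linear model alone, because the flexible part picks up the sine-shaped effect that a straight line cannot represent.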

Predicting PIE

Once your PIE model is trained, you can call predict() on the fitted object to generate predictions for new data. Here we predict on the validation split of the same fold.

# Prediction
pred <- predict(fit, 
  X = dat$spl_validation_X[[fold]],
  X_orig = dat$orig_validation_X[[fold]])

Evaluate PIE

You can evaluate your PIE model’s performance with RPE(), which has formula $\mathrm{RPE} = \frac{\sum_{i} (y_i - \hat{y}_i)^2}{\sum_{i} (y_i - \bar{y})^2}$, where $\bar{y} = \frac{1}{n} \sum_{i=1}^{n} y_i$.

# Validation
val_rrmse_test <- RPE(pred$total, dat$validation_y[[fold]])
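To see how the RPE formula behaves, the following base-R sketch evaluates it by hand on a tiny hypothetical vector. The helper rpe_sketch() and its argument order are assumptions made for illustration; it is not the package's RPE() function.

```r
# Hypothetical helper that mirrors the RPE formula; not the package's RPE()
rpe_sketch <- function(pred, y) {
  stopifnot(length(pred) == length(y))
  # Squared prediction error divided by squared deviation from the mean of y
  sum((y - pred)^2) / sum((y - mean(y))^2)
}

y <- c(3, 5, 7, 9)

rpe_perfect <- rpe_sketch(y, y)                        # 0: perfect predictions
rpe_mean    <- rpe_sketch(rep(mean(y), length(y)), y)  # 1: always predicting the mean
rpe_close   <- rpe_sketch(c(4, 5, 7, 8), y)            # 0.1: 2 / 20
```

Read this way, an RPE of 0 means perfect predictions and an RPE of 1 means the model does no better than predicting the mean of y. On the data it is evaluated on, the quantity equals one minus the usual R-squared.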