| Title: | Multimodal Late Fusion with 'caret' |
| Version: | 1.0.0 |
| Description: | Extends the 'caret' framework to support late fusion workflows, enabling users to train models independently across multiple data modalities and combine their predictions into a single meta-model. Designed for developers, data scientists, and biomedical researchers alike, 'caretMultimodal' aims to make late fusion ensemble modelling as accessible and flexible as single-dataset workflows in 'caret'. Late fusion methods are based on Wolpert (1992) <doi:10.1016/S0893-6080(05)80023-1>. |
| License: | MIT + file LICENSE |
| Encoding: | UTF-8 |
| RoxygenNote: | 7.3.3 |
| Suggests: | testthat (≥ 3.0.0), randomForest, doParallel |
| Imports: | caret, data.table, ggplot2, pROC, foreach, viridis, MultiAssayExperiment, glmnet |
| Config/testthat/edition: | 3 |
| URL: | https://github.com/CompBio-Lab/caretMultimodal, https://compbio-lab.github.io/caretMultimodal/ |
| BugReports: | https://github.com/CompBio-Lab/caretMultimodal/issues |
| Depends: | R (≥ 3.5) |
| LazyData: | true |
| NeedsCompilation: | no |
| Packaged: | 2026-06-24 16:21:44 UTC; jpdyc |
| Author: | Josh Dyce [aut, cre], Amrit Singh [aut] |
| Maintainer: | Josh Dyce <jpdyce@gmail.com> |
| Repository: | CRAN |
| Date/Publication: | 2026-06-30 12:30:07 UTC |
Construct a caret_list object
Description
Builds a list of caret::train objects, where each model corresponds to a data set in data_list.
The resulting list is used as input to caret_stack() to construct a meta model.
Usage
caret_list(
target,
data_list,
method,
identifier_column_name = NULL,
trControl = NULL,
metric = NULL,
trim = TRUE,
do_parallel = TRUE,
...
)
Arguments
target |
Target vector, either numeric for regression or a factor/character for classification. |
data_list |
A named list of matrix-like objects, where each element is a dataset to train a separate model on.
Names are preserved in the returned caret_list. |
method |
The method to train the models with. Can be a custom method or one of the built-in caret models (see caret::modelLookup()). |
identifier_column_name |
A string giving the name of a column that links rows across datasets (e.g. a participant ID).
If provided, this column must be present in all datasets in data_list. Default is NULL. |
trControl |
Control object for use with caret::train (see caret::trainControl()). Default is NULL. |
metric |
Metric for use with caret::train. Default is NULL. |
trim |
Logical, whether the trained models should be trimmed to save memory. Default is TRUE. |
do_parallel |
Logical, whether to parallelize model training across datasets. Default is TRUE. |
... |
Additional arguments to pass to caret::train (e.g. tuneGrid, as in the example below). |
Value
A caret_list object (a named list of trained caret::train models corresponding to data_list).
Examples
set.seed(42)
data(heart_failure_datasets)
data_list <- heart_failure_datasets[c("cells", "holter", "mrna", "proteins")]
# Define hyperparameters to tune (optional)
tuneGrid <- expand.grid(alpha = 0.5, lambda = c(0.01, 0.1))
# Construct caret_list object
base_models <- caret_list(
target = heart_failure_datasets$demo$hospitalizations,
data_list = data_list,
method = "glmnet",
tuneGrid = tuneGrid
)
class(base_models)
Construct a caret_stack object.
Description
Train an ensemble (stacked) model from the base learners in a
caret_list. The ensemble is itself a caret::train model that learns to
combine the predictions of the base models. By default, the meta-learner is
trained on out-of-fold predictions from the resampling process, so it never
sees a base model's predictions for samples that model was fitted on; this
reduces the risk of the ensemble overfitting to optimistic in-sample
predictions. Alternatively, new datasets can be supplied via data_list and
target for transfer-learning style ensembling.
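The out-of-fold stacking idea described above can be sketched in base R. This is a minimal illustration of stacked generalization under stated assumptions, not the package's implementation: stats::lm stands in for caret::train at both levels, the two modalities (mod_a, mod_b) are simulated, and 5-fold cross-validation is assumed.

```r
# Minimal base-R sketch of stacked generalization (Wolpert, 1992) across modalities.
# Assumptions: stats::lm replaces caret::train; data and fold count are illustrative.
set.seed(42)
n <- 60
target <- rnorm(n)

# Two simulated modalities measured on the same samples (rows = samples)
data_list <- list(
  mod_a = data.frame(a1 = target + rnorm(n), a2 = rnorm(n)),
  mod_b = data.frame(b1 = 0.5 * target + rnorm(n), b2 = rnorm(n))
)

k <- 5
folds <- sample(rep(seq_len(k), length.out = n))  # fold id for each sample

# Level 0: one base model per dataset, each sample predicted only when held out
oof <- sapply(data_list, function(df) {
  preds <- numeric(n)
  for (f in seq_len(k)) {
    held_out <- folds == f
    train_df <- data.frame(y = target[!held_out], df[!held_out, , drop = FALSE])
    fit <- lm(y ~ ., data = train_df)
    preds[held_out] <- predict(fit, newdata = df[held_out, , drop = FALSE])
  }
  preds
})

# Level 1: the meta-learner combines the base-model predictions (one column per dataset)
meta_fit <- lm(y ~ ., data = data.frame(y = target, oof))
coef(meta_fit)
```

Each column of oof holds one modality's held-out predictions, so the meta-learner's coefficients reflect how useful each modality is on data its base model did not see.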
Usage
caret_stack(
caret_list,
method,
data_list = NULL,
target = NULL,
trControl = NULL,
metric = NULL,
...
)
Arguments
caret_list |
A caret_list object. |
method |
The method to train the ensemble model. Can be a custom method or one of the built-in caret models (see caret::modelLookup()). |
data_list |
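A list of datasets to predict on, with each dataset matching the corresponding model in caret_list. Default is NULL, in which case the out-of-fold predictions of the base models are used. |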
target |
Target vector; must be provided when data_list is supplied. Default is NULL. |
trControl |
Control object for use with caret::train (see caret::trainControl()). Default is NULL. |
metric |
Metric for use with caret::train. Default is NULL. |
... |
Additional arguments to pass to caret::train. |
Value
A caret_stack object.
Examples
set.seed(42)
data(heart_failure_datasets)
data_list <- heart_failure_datasets[c("cells", "holter", "mrna", "proteins")]
# Define hyperparameters to tune (optional)
tuneGrid <- expand.grid(alpha = 0.5, lambda = c(0.01, 0.1))
# Construct caret_list object
base_models <- caret_list(
target = heart_failure_datasets$demo$hospitalizations,
data_list = data_list,
method = "glmnet",
tuneGrid = tuneGrid
)
# Train a Random Forest stacked model on the out-of-fold predictions from the base models
stacked_model <- caret_stack(
caret_list = base_models,
method = "rf"
)
class(stacked_model)
Conduct an ablation analysis with a caret_stack object
Description
Conduct an ablation analysis with a caret_stack object
Usage
compute_ablation(object, ...)
Arguments
object |
A caret_stack object. |
... |
Additional arguments passed to class-specific methods |
Value
A data.frame
Perform an ablation analysis for a caret_stack
Description
This function performs an ablation analysis on a caret_stack ensemble to evaluate
each base learner's contribution to predictive performance.
Starting from the full ensemble, the procedure iteratively removes one base learner per step. At each step:
- The ensemble meta-learner is retrained on the remaining base learners, using the same method, tuneGrid, and trControl as the original stack.
- Variable importance scores are extracted from the retrained meta-learner to estimate each remaining learner's relative contribution.
- Out-of-fold predictions are generated and scored with metric_function.
- The learner with the lowest importance score (or the highest, if reverse = TRUE) is removed before the next iteration.
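The ablation steps described above can be sketched as a base-R loop. This is a minimal illustration, not the package's code. Assumptions: a stats::lm meta-learner, importance measured as the absolute t-statistic (the statistic caret::varImp reports for linear models), RMSE as the metric, and simulated out-of-fold base-model predictions that merely borrow the example dataset names.

```r
# Minimal base-R sketch of importance-guided ablation of base learners.
set.seed(7)
n <- 80
target <- rnorm(n)

# Simulated out-of-fold base-model predictions, one column per base learner
base_oof <- data.frame(
  cells    = target + rnorm(n, sd = 0.5),
  holter   = target + rnorm(n, sd = 1),
  mrna     = target + rnorm(n, sd = 2),
  proteins = rnorm(n)
)
folds <- sample(rep(1:5, length.out = n))

# Metric function with the (preds, target) signature used in the examples
rmse <- function(preds, target) sqrt(mean((preds - target)^2))

# 5-fold out-of-fold predictions from an lm meta-learner on the given learners
meta_oof <- function(learners) {
  df <- data.frame(y = target, base_oof[, learners, drop = FALSE])
  preds <- numeric(n)
  for (f in 1:5) {
    fit <- lm(y ~ ., data = df[folds != f, , drop = FALSE])
    preds[folds == f] <- predict(fit, newdata = df[folds == f, , drop = FALSE])
  }
  preds
}

remaining <- names(base_oof)
removed <- NA_character_
steps <- list()
repeat {
  # Score the current set of learners on out-of-fold meta-learner predictions
  steps[[length(steps) + 1]] <- data.frame(
    removed = removed, n_learners = length(remaining),
    metric = rmse(meta_oof(remaining), target), stringsAsFactors = FALSE
  )
  if (length(remaining) == 1) break
  # Importance: |t| of each learner's coefficient in a meta-learner fit on all samples
  fit <- lm(y ~ ., data = data.frame(y = target, base_oof[, remaining, drop = FALSE]))
  importance <- abs(summary(fit)$coefficients[remaining, "t value"])
  # Drop the least important learner (which.max() would mimic reverse = TRUE)
  removed <- names(which.min(importance))
  remaining <- setdiff(remaining, removed)
}
ablation <- do.call(rbind, steps)
ablation
```

Each row records which learner was removed before that step, how many learners remain, and the out-of-fold score of the retrained meta-learner.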
Usage
## S3 method for class 'caret_stack'
compute_ablation(object, metric_function, metric_name, reverse = FALSE, ...)
Arguments
object |
A caret_stack object. |
metric_function |
A function that takes two arguments, the predictions and the ground-truth target, and returns a single number. |
metric_name |
The name of the metric. Used as a row label in the returned data.table. |
reverse |
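Logical, controls the direction of ablation. If TRUE, the base learner with the highest importance score is removed at each step; otherwise, the one with the lowest. Default is FALSE. |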
... |
Not used. Included for S3 compatibility. |
Value
A data.table
Note
This function does not support multiclass classifiers.
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
# Since the example stack is a binary classifier,
# this metric function needs to take in predictions (floats) and
# ground truth (binary vector), and produce a single number.
metric_fun <- function(preds, target) {
pROC::roc(response = target, predictor = preds, quiet = TRUE)$auc
}
compute_ablation(heart_failure_stack, metric_fun, "AUC")
Compute feature level contributions for a caret_stack object
Description
Compute feature level contributions for a caret_stack object
Usage
compute_feature_contributions(object, ...)
Arguments
object |
A caret_stack object. |
... |
Additional arguments passed to class-specific methods |
Value
A data.frame
Compute the feature level contributions for a caret_stack.
Description
Computes the contribution of each individual feature to the ensemble's
predictions using a two-stage application of caret::varImp:
- Dataset-level weights: varImp is applied to the ensemble meta-learner, treating each base model's predictions as a feature. This yields a relative importance weight for each dataset.
- Feature-level importance: varImp is applied to each base model individually, yielding feature importance scores within each dataset.
The final contribution of a feature is the product of its dataset-level weight and its within-dataset feature importance score. All scores are normalized to sum to 100.
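The product-and-normalize step described above can be illustrated with a short base-R sketch. The importance scores below are hypothetical stand-ins for caret::varImp output, and the feature names are invented for illustration.

```r
# Stage 1 (hypothetical): dataset-level weights from the meta-learner, one per base model
dataset_weight <- c(cells = 30, mrna = 70)

# Stage 2 (hypothetical): feature importance within each dataset, from each base model
feature_imp <- list(
  cells = c(cd4 = 50, cd8 = 25, nk = 25),
  mrna  = c(geneA = 80, geneB = 20)
)

# Contribution = dataset weight x within-dataset feature importance
contrib <- unlist(lapply(names(feature_imp), function(ds) {
  imp <- feature_imp[[ds]]
  setNames(dataset_weight[[ds]] * imp, paste(ds, names(imp), sep = "_"))
}))

# Normalize so all contributions sum to 100, then keep the top features
contrib <- 100 * contrib / sum(contrib)
head(sort(contrib, decreasing = TRUE), 3)  # mrna_geneA 56, cells_cd4 15, mrna_geneB 14
```

A feature from a heavily weighted dataset can outrank a feature that is more important within a weakly weighted dataset, which is the point of the two-stage weighting.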
Usage
## S3 method for class 'caret_stack'
compute_feature_contributions(object, n_features = 20, ...)
Arguments
object |
A |
n_features |
The maximum number of features to include. Setting to a very large value will include all features. Default is 20. |
... |
Not used. Included for S3 compatibility. |
Value
A data.table
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
compute_feature_contributions(heart_failure_stack)
Compute metrics for a caret_stack object
Description
Compute metrics for a caret_stack object
Usage
compute_metric(object, ...)
Arguments
object |
A caret_stack object. |
... |
Additional arguments passed to class-specific methods |
Value
A data.frame
Compute metrics with a provided metric function
Description
The metric_function is applied to the out-of-fold predictions for the caret_stack.
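A minimal sketch of applying one metric function to every column of an out-of-fold prediction table (base models plus ensemble) and sorting the result. The predictions are hypothetical, and auc_rank() is a hypothetical base-R stand-in for the pROC-based metric in the examples, computing AUC with the Mann-Whitney rank formula and assuming the second factor level is the positive class.

```r
# Hypothetical AUC helper with the (preds, target) signature used in the examples
auc_rank <- function(preds, target) {
  pos <- target == levels(target)[2]  # assumption: second level is the positive class
  n_pos <- sum(pos)
  n_neg <- sum(!pos)
  r <- rank(preds)
  (sum(r[pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

target <- factor(c("no", "no", "no", "yes", "yes", "yes"))

# Hypothetical out-of-fold predictions: one column per base model plus the ensemble
oof <- data.frame(
  cells    = c(0.1, 0.4, 0.35, 0.8, 0.3, 0.9),
  proteins = c(0.2, 0.65, 0.3, 0.6, 0.7, 0.8),
  ensemble = c(0.05, 0.2, 0.3, 0.7, 0.6, 0.9)
)

# Apply the metric column by column, then sort (as with descending = TRUE)
metrics <- vapply(oof, auc_rank, numeric(1), target = target)
sort(metrics, decreasing = TRUE)  # ensemble = 1, proteins = 8/9, cells = 7/9
```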
Usage
## S3 method for class 'caret_stack'
compute_metric(object, metric_function, metric_name, descending = TRUE, ...)
Arguments
object |
A caret_stack object. |
metric_function |
A function that takes two arguments, the predictions and the ground-truth target, and returns a single number. |
metric_name |
The name of the metric |
descending |
Logical, whether to sort in descending order. If FALSE, results are sorted in ascending order. Default is TRUE. |
... |
Not used. Included for S3 compatibility. |
Value
A data.table of metrics
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
# Since the example stack is a binary classifier,
# this metric function needs to take in predictions (floats) and
# ground truth (binary vector), and produce a single number.
metric_fun <- function(preds, target) {
pROC::roc(response = target, predictor = preds, quiet = TRUE)$auc
}
compute_metric(heart_failure_stack, metric_fun, "AUC")
Compute base model contributions for a caret_stack object
Description
Compute base model contributions for a caret_stack object
Usage
compute_model_contributions(object, ...)
Arguments
object |
A caret_stack object. |
... |
Additional arguments passed to class-specific methods |
Value
A data.frame
Compute the relative contributions of each of the base models in the ensemble model
Description
The relative contributions are calculated using the caret::varImp function on the ensemble model.
A scaling factor is applied to make the contributions sum to 100%.
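The scaling step amounts to multiplying every raw importance score by 100 divided by their total. A minimal sketch, with hypothetical raw scores standing in for caret::varImp output on the ensemble model:

```r
# Hypothetical raw importance scores, one per base model
raw_importance <- c(cells = 6, holter = 1.5, mrna = 12.5, proteins = 30)

# Scaling factor that makes the contributions sum to 100%
scaling_factor <- 100 / sum(raw_importance)
contributions <- raw_importance * scaling_factor

sort(contributions, decreasing = TRUE)  # proteins 60, mrna 25, cells 12, holter 3
```

Because every score is multiplied by the same factor, the ranking of base models is unchanged; only the scale becomes interpretable as a percentage.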
Usage
## S3 method for class 'caret_stack'
compute_model_contributions(object, descending = TRUE, ...)
Arguments
object |
A caret_stack object. |
descending |
Logical, whether to sort in descending order. If FALSE, results are sorted in ascending order. Default is TRUE. |
... |
Not used. Included for S3 compatibility. |
Value
A data.table
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
compute_model_contributions(heart_failure_stack)
Heart Failure Datasets
Description
A multimodal dataset from Singh et al. (2019) containing demographic, cellular, electrophysiological, and molecular features for predicting cardiac-related hospitalizations. Used in examples throughout the caretMultimodal package.
Usage
heart_failure_datasets
Format
A named list with 5 elements:
- demo: A data.frame of demographic features
- cells: A data.frame of cell count features
- holter: A data.frame of Holter monitor (ECG) features
- mrna: A data.frame of mRNA expression features
- proteins: A data.frame of protein abundance features
Source
Singh et al. Ensembling Electrical and Proteogenomics Biomarkers for Improved Prediction of Cardiac-Related 3-Month Hospitalizations: A Pilot Study. Can J Cardiol. 2019 Apr. doi:10.1016/j.cjca.2018.11.021
Pre-trained caret_stack on Heart Failure Datasets
Description
A caret_stack object pre-trained on heart_failure_datasets.
Used in examples throughout the caretMultimodal package.
Usage
heart_failure_stack
Format
A caret_stack object
Extract out-of-fold predictions from caret_list or caret_stack objects
Description
Extract out-of-fold predictions from caret_list or caret_stack objects
Usage
oof_predictions(object, ...)
Arguments
object |
A caret_list or caret_stack object. |
... |
Additional arguments passed to class-specific methods |
Value
A data.frame
Out-of-fold predictions from a caret_list
Description
Retrieve the out-of-fold predictions corresponding to the best hyperparameter setting of the trained caret models. These predictions come from the resampling process (not the final refit) and can optionally be aggregated across resamples to produce a single prediction per training instance.
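A minimal sketch of aggregating resampled out-of-fold predictions into one prediction per training instance. The data frame mimics the layout of the pred element that caret::train stores when trainControl(savePredictions = "final") is used (columns pred, obs, rowIndex, Resample); the values, the repeated-CV layout, and the use of the mean as the aggregation rule are illustrative assumptions.

```r
# Hypothetical resampled predictions: each training row was predicted once per repeat
pred_df <- data.frame(
  rowIndex = c(1, 2, 3, 1, 2, 3),
  Resample = c("Fold1.Rep1", "Fold2.Rep1", "Fold1.Rep1",
               "Fold2.Rep2", "Fold1.Rep2", "Fold1.Rep2"),
  pred     = c(0.2, 0.6, 0.9, 0.4, 0.8, 0.7),
  obs      = c(0, 1, 1, 0, 1, 1)
)

# One prediction per training instance: average across resamples by rowIndex
oof <- aggregate(pred ~ rowIndex, data = pred_df, FUN = mean)
oof  # rowIndex 1, 2, 3 with mean predictions 0.3, 0.7, 0.8
```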
Usage
## S3 method for class 'caret_list'
oof_predictions(
object,
drop_redundant_class = TRUE,
aggregate_resamples = TRUE,
intersection_only = TRUE,
...
)
Arguments
object |
A caret_list object. |
drop_redundant_class |
Logical, whether to exclude the first class level from prediction output. Default is TRUE. |
aggregate_resamples |
Logical, whether to aggregate predictions across resamples to produce a single prediction per training instance. Default is TRUE. |
intersection_only |
Logical, whether to trim the out-of-fold predictions to the intersection of
samples present across all models in the list (i.e., the intersection of training indices used during resampling).
Default is TRUE. |
... |
Not used. Included for S3 compatibility. |
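The intersection_only behaviour can be illustrated with a base-R sketch: find the sample indices shared by every model, then align each model's predictions on them. The per-model tables below are hypothetical and keyed by rowIndex for illustration.

```r
# Hypothetical per-model out-of-fold predictions covering different samples
oof_by_model <- list(
  cells    = data.frame(rowIndex = c(1, 2, 3, 5), pred = c(0.1, 0.5, 0.7, 0.4)),
  holter   = data.frame(rowIndex = c(1, 3, 4, 5), pred = c(0.2, 0.6, 0.3, 0.9)),
  proteins = data.frame(rowIndex = c(1, 2, 3, 5), pred = c(0.3, 0.4, 0.8, 0.5))
)

# Sample indices present in every model
shared <- Reduce(intersect, lapply(oof_by_model, `[[`, "rowIndex"))

# Align each model's predictions on the shared samples (one column per model)
aligned <- sapply(oof_by_model, function(df) df$pred[match(shared, df$rowIndex)])
rownames(aligned) <- shared
aligned
```

Only samples 1, 3 and 5 are predicted by all three models, so the aligned table has three rows.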
Value
A data.table::data.table of out-of-fold predictions, with samples as rows and predictions as columns.
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
# Extract the caret_list object from the caret_stack
base_models <- heart_failure_stack$caret_list
oof_predictions(base_models)
Out-of-fold predictions from a caret_stack
Description
Retrieve the out-of-fold predictions corresponding to the best hyperparameter setting of a trained ensemble model. These predictions come from the resampling process (not the final refit) and can optionally be aggregated across resamples to produce a single prediction per training instance.
The base model predictions returned here are the training data for the ensemble; depending on model setup, these may be true out-of-fold predictions or simply fitted values. For classification models, the predictions always exclude the first class index.
Usage
## S3 method for class 'caret_stack'
oof_predictions(
object,
drop_redundant_class = TRUE,
aggregate_resamples = TRUE,
...
)
Arguments
object |
A caret_stack object. |
drop_redundant_class |
Logical, whether to exclude the first class level from prediction output. Default is TRUE. |
aggregate_resamples |
Logical, whether to aggregate predictions across resamples to produce a single prediction per training instance. Default is TRUE. |
... |
Not used. Included for S3 compatibility. |
Value
A data.table::data.table of OOF predictions
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
oof_predictions(heart_failure_stack)
Plot an ablation analysis with a caret_stack object
Description
Plot an ablation analysis with a caret_stack object
Usage
plot_ablation(object, ...)
Arguments
object |
A caret_stack object. |
... |
Additional arguments passed to class-specific methods |
Value
A ggplot2 object
Make a bar plot of an ablation analysis for a caret_stack.
Description
Makes a bar plot from compute_ablation.caret_stack output.
Usage
## S3 method for class 'caret_stack'
plot_ablation(object, metric_function, metric_name, reverse = FALSE, ...)
Arguments
object |
A caret_stack object. |
metric_function |
A function that takes two arguments, the predictions and the ground-truth target, and returns a single number. |
metric_name |
The name of the metric, used as a label. |
reverse |
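Logical, controls the direction of ablation. If TRUE, the base learner with the highest importance score is removed at each step; otherwise, the one with the lowest. Default is FALSE. |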
... |
Not used. Included for S3 compatibility. |
Value
A ggplot2 bar plot
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
# Since the example stack is a binary classifier,
# this metric function needs to take in predictions (floats) and
# ground truth (binary vector), and produce a single number.
metric_fun <- function(preds, target) {
pROC::roc(response = target, predictor = preds, quiet = TRUE)$auc
}
plot_ablation(heart_failure_stack, metric_fun, "AUC")
Plot feature level contributions for a caret_stack object
Description
Plot feature level contributions for a caret_stack object
Usage
plot_feature_contributions(object, ...)
Arguments
object |
A caret_stack object. |
... |
Additional arguments passed to class-specific methods |
Value
A ggplot2 object
Make a bar plot of feature-level contributions for a caret_stack.
Description
Constructs a bar plot with the output of compute_feature_contributions.caret_stack.
Usage
## S3 method for class 'caret_stack'
plot_feature_contributions(object, n_features = 20, ...)
Arguments
object |
A caret_stack object. |
n_features |
The maximum number of features to include. Setting to a very large value will include all features. Default is 20. |
... |
Not used. Included for S3 compatibility. |
Value
A ggplot2 bar plot.
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
plot_feature_contributions(heart_failure_stack)
Plot metrics for a caret_stack object
Description
Plot metrics for a caret_stack object
Usage
plot_metric(object, ...)
Arguments
object |
A caret_stack object. |
... |
Additional arguments passed to class-specific methods |
Value
A ggplot2 object
Plot metrics computed with a provided metric function
Description
This function constructs a bar plot from the output of compute_metric.caret_stack. The bars are ordered by increasing value.
Usage
## S3 method for class 'caret_stack'
plot_metric(object, metric_function, metric_name, descending = TRUE, ...)
Arguments
object |
A caret_stack object. |
metric_function |
A function that takes two arguments, the predictions and the ground-truth target, and returns a single number. |
metric_name |
The name of the metric |
descending |
Logical, whether to sort in descending order. If FALSE, results are sorted in ascending order. Default is TRUE. |
... |
Not used. Included for S3 compatibility. |
Value
A ggplot2 bar chart
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
# Since the example stack is a binary classifier,
# this metric function needs to take in predictions (floats) and
# ground truth (binary vector), and produce a single number.
metric_fun <- function(preds, target) {
pROC::roc(response = target, predictor = preds, quiet = TRUE)$auc
}
plot_metric(heart_failure_stack, metric_fun, "AUC")
Plot base model contributions for a caret_stack object
Description
Plot base model contributions for a caret_stack object
Usage
plot_model_contributions(object, ...)
Arguments
object |
A caret_stack object. |
... |
Additional arguments passed to class-specific methods |
Value
A ggplot2 object
Plot the relative contributions of each of the base models in the ensemble model
Description
The relative contributions are calculated using the caret::varImp function on the ensemble model.
A scaling factor is applied to make the contributions sum to 100%.
Usage
## S3 method for class 'caret_stack'
plot_model_contributions(object, descending = TRUE, ...)
Arguments
object |
A caret_stack object. |
descending |
Logical, whether to sort in descending order. If FALSE, results are sorted in ascending order. Default is TRUE. |
... |
Not used. Included for S3 compatibility. |
Value
A ggplot2 bar chart
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
plot_model_contributions(heart_failure_stack)
Make an ROC plot for a caret_stack object
Description
Make an ROC plot for a caret_stack object
Usage
plot_roc(object, ...)
Arguments
object |
A caret_stack object. |
... |
Additional arguments passed to class-specific methods |
Value
A ggplot2 object
Plot ROC curves for individual and ensemble models in a caret_stack
Description
This function calculates ROC curves for all base models and the ensemble model
using the out-of-fold predictions from a caret_stack object.
The pROC package is used to compute the ROC curves. ROC curves can only be constructed for binary classifiers.
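For intuition about what each curve summarizes, the base-R sketch below computes true- and false-positive rates at every threshold and the AUC by the trapezoidal rule. The package itself relies on pROC; roc_points() and the predictions and labels here are hypothetical.

```r
# Hypothetical helper: TPR and FPR at each threshold (predict positive if pred >= t)
roc_points <- function(preds, positive) {
  thresholds <- sort(unique(c(Inf, preds)), decreasing = TRUE)
  tpr <- vapply(thresholds, function(t) mean(preds[positive] >= t), numeric(1))
  fpr <- vapply(thresholds, function(t) mean(preds[!positive] >= t), numeric(1))
  data.frame(threshold = thresholds, fpr = fpr, tpr = tpr)
}

# Hypothetical out-of-fold predictions and binary labels
preds    <- c(0.9, 0.8, 0.35, 0.3, 0.4, 0.1)
positive <- c(TRUE, TRUE, FALSE, TRUE, FALSE, FALSE)

roc <- roc_points(preds, positive)

# Trapezoidal area under the (fpr, tpr) curve
auc <- sum(diff(roc$fpr) * (head(roc$tpr, -1) + tail(roc$tpr, -1)) / 2)
auc  # 7/9, about 0.778
```

The curve starts at (0, 0) for the infinite threshold and ends at (1, 1) once every sample is called positive.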
Usage
## S3 method for class 'caret_stack'
plot_roc(object, include_auc = TRUE, ...)
Arguments
object |
A caret_stack object. |
include_auc |
Logical, whether to include AUC values in the legend. Default is TRUE. |
... |
Not used. Included for S3 compatibility. |
Value
A ggplot2 object
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
plot_roc(heart_failure_stack)
Predict from a caret_list
Description
Generate a matrix of predictions from each model in a caret_list.
For classification models, probabilities are always returned, with the option to drop
one class to avoid multicollinearity in downstream stacking models.
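A minimal sketch of why one class column can be dropped: class probabilities sum to 1 in every row, so the last column is a linear function of the others and adds no information to a stacking model. The probability table is hypothetical.

```r
# Hypothetical class probabilities from one binary classifier
probs <- data.frame(
  no  = c(0.7, 0.2, 0.55),
  yes = c(0.3, 0.8, 0.45)
)

# Rows sum to 1, so "no" is exactly 1 - "yes"
stopifnot(isTRUE(all.equal(unname(rowSums(probs)), rep(1, nrow(probs)))))

# Drop the first class level, keeping one informative column per model
probs_kept <- probs[, -1, drop = FALSE]
names(probs_kept)  # "yes"
```

With K classes, the same reasoning keeps K - 1 probability columns per model.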
Usage
## S3 method for class 'caret_list'
predict(object, data_list, drop_redundant_class = TRUE, ...)
Arguments
object |
A caret_list object. |
data_list |
A list of datasets to predict on, with each dataset matching the corresponding model in the caret_list. |
drop_redundant_class |
Logical, whether to exclude the first class level from prediction output. Default is TRUE. |
... |
Additional arguments passed on to the underlying prediction calls. |
Value
A data.table::data.table of predictions
Examples
# Load example data and pre-trained caret_stack object
data(heart_failure_datasets)
data(heart_failure_stack)
# Extract the caret_list object from the caret_stack
base_models <- heart_failure_stack$caret_list
# List of datasets to predict on
data_list <- heart_failure_datasets[c("cells", "holter", "mrna", "proteins")]
predict(base_models, data_list)
Create a matrix of predictions for a caret_stack object.
Description
Create a matrix of predictions for a caret_stack object.
Usage
## S3 method for class 'caret_stack'
predict(object, data_list, drop_redundant_class = TRUE, ...)
Arguments
object |
A caret_stack object. |
data_list |
A list of datasets to predict on, with each dataset matching the corresponding base model in the caret_stack. |
drop_redundant_class |
Logical, whether to exclude the first class level from prediction output. Default is TRUE. |
... |
Additional arguments passed on to the underlying prediction calls. |
Value
A data.table::data.table of predictions for base and ensemble models.
Examples
# Load example data and pre-trained caret_stack object
data(heart_failure_datasets)
data(heart_failure_stack)
# List of datasets to predict on
data_list <- heart_failure_datasets[c("cells", "holter", "mrna", "proteins")]
predict(heart_failure_stack, data_list)
Prepare a MultiAssayExperiment for use with caretMultimodal
Description
Converts a MultiAssayExperiment object from the MultiAssayExperiment package to a
simple list of datasets to pass into caret_list.
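The transpose option concerns orientation. Bioconductor assays, including those held in a MultiAssayExperiment, conventionally store features in rows and samples in columns, whereas caret_list expects samples in rows. A minimal base-R sketch, with hypothetical matrices standing in for assays extracted from an MAE object:

```r
# Hypothetical assays in Bioconductor orientation (features x samples)
assays <- list(
  mrna     = matrix(1:6, nrow = 2,
                    dimnames = list(c("geneA", "geneB"), c("s1", "s2", "s3"))),
  proteins = matrix(7:12, nrow = 3,
                    dimnames = list(c("p1", "p2", "p3"), c("s1", "s2")))
)

# Transpose each assay so rows are samples and columns are features
data_list <- lapply(assays, t)
lapply(data_list, dim)  # mrna: 3 x 2, proteins: 2 x 3
```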
Usage
prepare_mae(mae, transpose = FALSE, ...)
Arguments
mae |
The MultiAssayExperiment object. |
transpose |
Whether to transpose the individual matrices. Samples must correspond to rows for caret_list. Default is FALSE. |
... |
Not used. Included for S3 compatibility. |
Value
A named list of matrices.
Provide a summary of the best tuning parameters and resampling metrics for all the caret_list models.
Description
Provide a summary of the best tuning parameters and resampling metrics for all the caret_list models.
Usage
## S3 method for class 'caret_list'
summary(object, ...)
Arguments
object |
A caret_list object. |
... |
Not used. Included for S3 compatibility. |
Value
A data.table with tunes and metrics from each model.
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
# Extract the caret_list object from the caret_stack
base_models <- heart_failure_stack$caret_list
summary(base_models)
Get a summary of a caret_stack object
Description
Get a summary of a caret_stack object
Usage
## S3 method for class 'caret_stack'
summary(object, ...)
Arguments
object |
A caret_stack object. |
... |
Not used. Included for S3 compatibility. |
Value
A data.table of methods, tuning parameters and performance metrics for the base and ensemble models
Examples
# Load pre-trained example caret_stack object
data(heart_failure_stack)
summary(heart_failure_stack)