This package provides functions for fitting Horseshoe Trees, Causal Horseshoe Forests, and their more general counterparts: Shrinkage Trees and Causal Shrinkage Forests.
These models allow for global-local shrinkage priors on tree step heights, enabling adaptive modeling in high-dimensional settings.
The functions can be used for:
Supported outcome types: continuous, binary, and right-censored survival times.
The mathematical background and theoretical foundations of these models are described in the preprint “Horseshoe Forests for High-Dimensional Causal Survival Analysis” by T. Jacobs, W.N. van Wieringen, and S.L. van der Pas (arXiv:2507.22004).
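To give some intuition for the global-local shrinkage priors mentioned above, the following base-R sketch draws from a standard horseshoe prior (a normal with a half-Cauchy local scale times a global scale) and compares its tails to a plain normal prior with the same global scale. This is an illustration of the prior family only, not the package's internal sampler; the value of `tau` and all object names are assumptions chosen for the example.

```r
set.seed(1)
n_draws <- 10000

# Hypothetical global scale shared by all step heights
tau <- 0.1

# Local scales: half-Cauchy(0, 1), one per step height
lambda <- abs(rcauchy(n_draws))

# Horseshoe draws: normal with standard deviation tau * lambda
h_horseshoe <- rnorm(n_draws, mean = 0, sd = tau * lambda)

# Reference: normal prior with the global scale only
h_normal <- rnorm(n_draws, mean = 0, sd = tau)

# Compare upper quantiles of the absolute values: the horseshoe has
# much heavier tails, so large signals are shrunk less
q_horseshoe <- quantile(abs(h_horseshoe), probs = 0.99)
q_normal <- quantile(abs(h_normal), probs = 0.99)
print(c(horseshoe = unname(q_horseshoe), normal = unname(q_normal)))
```

The heavy tails let large step heights escape shrinkage, while the mass near zero pulls irrelevant ones toward zero.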
The released version of ShrinkageTrees can be installed from CRAN:
install.packages("ShrinkageTrees")
You can install the development version from GitHub:
# Install devtools if not already installed
install.packages("devtools")
devtools::install_github("tijn-jacobs/ShrinkageTrees")
library(ShrinkageTrees)
set.seed(42)
n <- 100
p <- 1000

# Generate covariates
X <- matrix(runif(n * p), ncol = p)
X_treat <- X_control <- X
treatment <- rbinom(n, 1, X[, 1])
tau <- 1 + X[, 2]/2 - X[, 3]/3 + X[, 4]/4

# Generate survival times (on log-scale)
true_time <- X[, 1] + treatment * tau + rnorm(n)
censor_time <- log(rexp(n, rate = 0.05))
follow_up <- pmin(true_time, censor_time)
status <- as.integer(true_time <= censor_time)

# Fit a standard Causal Horseshoe Forest (without propensity score adjustment)
fit_horseshoe <- CausalHorseForest(
  y = follow_up,
  status = status,
  X_train_control = X_control,
  X_train_treat = X_treat,
  treatment_indicator_train = treatment,
  outcome_type = "right-censored",
  timescale = "log",
  number_of_trees = 200,
  k = 0.1,
  N_post = 5000,
  N_burn = 5000,
  store_posterior_sample = TRUE
)

# Posterior mean CATEs
CATE_horseshoe <- colMeans(fit_horseshoe$train_predictions_sample_treat)

# Posteriors of the ATE
post_ATE_horseshoe <- rowMeans(fit_horseshoe$train_predictions_sample_treat)

# Posterior mean ATE
ATE_horseshoe <- mean(post_ATE_horseshoe)

# Plot the posterior of the ATE
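One way to summarise and plot an ATE posterior in base R is sketched below. It does not call the package: `draws` is a hypothetical stand-in for `fit_horseshoe$train_predictions_sample_treat`, assumed to be a matrix with one row per posterior draw and one column per observation, as the `colMeans` and `rowMeans` calls above imply. The reference value 29/24 is the expected treatment effect in the simulation: each covariate is uniform on (0, 1) with mean 1/2, so the expected `tau` is 1 + 1/4 - 1/6 + 1/8 = 29/24, roughly 1.21.

```r
set.seed(1)

# Hypothetical stand-in for fit_horseshoe$train_predictions_sample_treat:
# rows are posterior draws, columns are observations
n_draws <- 500
n_obs <- 100
draws <- matrix(rnorm(n_draws * n_obs, mean = 29 / 24, sd = 0.5),
                nrow = n_draws, ncol = n_obs)

# One ATE value per posterior draw, then its mean and 95% credible interval
post_ATE <- rowMeans(draws)
ATE_mean <- mean(post_ATE)
ATE_ci <- quantile(post_ATE, probs = c(0.025, 0.975))

# Pointwise 95% credible intervals for the CATEs (a 2 x n_obs matrix)
CATE_ci <- apply(draws, 2, quantile, probs = c(0.025, 0.975))

# Histogram of the ATE posterior with the simulation's expected ATE marked
hist(post_ATE, breaks = 30,
     main = "Posterior of the ATE", xlab = "ATE (log time scale)")
abline(v = 29 / 24, lty = 2)
```

With a real fit, replacing `draws` by the stored posterior sample should give the corresponding summaries, provided `store_posterior_sample = TRUE` was set as in the call above.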
The package includes a demo analysis based on the TCGA PAAD (pancreatic cancer) dataset to showcase how ShrinkageTrees can be used in practice. This demo replicates the main case study from the preprint “Horseshoe Forests for High-Dimensional Causal Survival Analysis” (arXiv:2507.22004).
The demo:
- Estimates propensity scores for treatment assignment
- Fits a Causal Horseshoe Forest to the survival times with right-censoring
- Computes the posterior mean ATE and individual CATEs with 95% credible intervals
- Produces diagnostic plots (propensity score overlap, posterior ATE, CATE estimates, sigma trace)
You can run it directly from R after installing the package:
demo("pdac_analysis", package = "ShrinkageTrees")
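The propensity score and overlap steps of such an analysis can be approximated without the TCGA data. The sketch below is an illustration only and is not the demo's code: it simulates a small data set, estimates propensity scores with a plain logistic regression (`glm`) as a stand-in for whichever model the demo uses, and draws a basic overlap plot. All object names and simulation settings are hypothetical.

```r
set.seed(1)

# Small simulated data set: treatment probability depends on the first covariate
n <- 200
X <- matrix(runif(n * 3), ncol = 3)
treatment <- rbinom(n, 1, plogis(2 * (X[, 1] - 0.5)))

# Estimate propensity scores with logistic regression (stand-in model)
ps_fit <- glm(treatment ~ X, family = binomial())
propensity <- fitted(ps_fit)

# Range of estimated propensity scores within each treatment arm
ps_range <- tapply(propensity, treatment, range)
print(ps_range)

# Overlap plot: overlaid histograms of propensity scores by arm
bins <- seq(0, 1, by = 0.05)
col_control <- rgb(0, 0, 1, 0.4)
col_treated <- rgb(1, 0, 0, 0.4)
hist(propensity[treatment == 0], breaks = bins, col = col_control,
     main = "Propensity score overlap", xlab = "Estimated propensity score")
hist(propensity[treatment == 1], breaks = bins, col = col_treated, add = TRUE)
legend("topright", legend = c("control", "treated"),
       fill = c(col_control, col_treated))
```

Regions where only one arm has mass indicate poor overlap, where treatment effect estimates rely more heavily on the model.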
See ?ShrinkageTrees, ?HorseTrees, ?CausalHorseForest, and ?CausalShrinkageForest for detailed help.
Contributions are welcome! Feel free to open an issue or submit a pull request. The software is designed to be flexible and modular, so that a wide variety of global-local shrinkage priors can be implemented and extended in future versions.
This package is licensed under the MIT License.
This project has received funding from the European Research Council (ERC) under the European Union’s Horizon Europe program under Grant agreement No. 101074082. Views and opinions expressed are however those of the author(s) only and do not necessarily reflect those of the European Union or the European Research Council Executive Agency. Neither the European Union nor the granting authority can be held responsible for them.