An introduction to Bayesian additive regression trees

Demetrios Papakostas

January 2025

Notes

  • Theme borrowed from Emi Tanaka

  • Also used elements from this excellent quarto tutorial

  • Notes and discussion from Andrew Herren, Rafael Alcantara, Rob McCulloch, Hedibert Lopes, and Richard Hahn provided excellent insights for these notes.

Introduction

  • BART (Chipman, George, and McCulloch 2012) is a prior over functions. The prior is a sum of decision trees, aka a forest.

  • BART is a Gaussian process, with a kernel learned from the data.

  • Anything a random forest (or XGBoost) can do, BART can (usually) do better, particularly learning complex functions in noisy situations.

  • BART’s Bayesian framework invites custom models. stochtree enables such modifications without having to edit C++.

More on stochtree

  • Implements BART and other stochastic tree variants. Fast, easy to use, and easy to customize.

  • Implements many bells and whistles.

Hold up, let’s do a recap on trees

  • Decision trees can be represented mathematically:

    \[ f(\mathbf{x})=\sum_{j=1}^{\text{# partitions}} \mu_{j}\cdot\mathbf{1}\left(\mathbf{x}\in \mathcal{S}_j\right) \]

This says that the input space is partitioned into regions \(\mathcal{S}_j\), and the prediction for \(\mathbf{x}\) is the mean \(\mu_j\) of the region it falls into (a minimal sketch follows).
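For concreteness, here is a minimal Python sketch of this piecewise-constant form for a single feature; the split points and leaf means are made-up values for illustration.

    import numpy as np

    def tree_predict(x, split_points, leaf_means):
        """Piecewise-constant prediction: f(x) = sum_j mu_j * 1(x in S_j).

        Hypothetical single-feature tree: the partitions S_j are the intervals
        defined by split_points, and leaf_means holds one mu_j per interval.
        """
        leaf_index = np.digitize(x, split_points)  # which interval each x falls into
        return np.asarray(leaf_means)[leaf_index]

    x = np.array([-1.2, 0.3, 2.5])
    print(tree_predict(x, split_points=[0.0, 1.0], leaf_means=[-0.4, 0.1, 0.7]))
    # [-0.4  0.1  0.7]: each point gets the mean of the partition it lands in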

Visually

A sample of a decision tree

Another visual

Excalidraw sketch

Why are trees useful?

  • I wanna predict how well I will sleep.

  • I expect some interactions between my daily habits with respect to my next night’s sleep. If it’s warm out and I slept well the night before, I tend to sleep well the next night. This alone may not be enough…

  • However, I don’t want to account for too many if-else statements, or else we’ll overfit to the idiosyncrasies of my day.

The B is for Bayesian

  • As good Bayesians, we know that the Bayesian recipe entails a prior, a likelihood, and a posterior.
  • Specifying a prior over decision tree space is tricky; sampling from the posterior of those trees is even harder.
  • Luckily, (Chipman, George, and McCulloch 1998) figured out a pretty good way to accomplish these tasks for a single tree.
  • (Chipman, George, and McCulloch 2012) improved the approach with the sum of trees, a fantastic model that should be the backbone of modern data science.

The BART model

\[ \begin{align} y_i &= \sum_{j=1}^{M}f(\mathbf{x}_i, T_j, \Theta_j)+\varepsilon_i, \qquad \varepsilon_i\sim N(0, \sigma^2)\\ E(y\mid \mathbf{x}) &= \text{signal} =\sum_{j=1}^{M}f(\mathbf{x}, T_j, \Theta_j)\\ \text{noise: } \sigma^2 &\sim \text{IG}\left(\frac{\nu}{2},\frac{\nu\lambda}{2}\right) \end{align} \]

  • \(T_j\) are decision trees, \(\mathcal{T}\) represents the set of \(M\) trees, and \(\Theta_j\) is the set of leaf outcome parameters \(\mu_{qj}\) (for leaves \(q=1, \ldots, \text{# partitions}\) in tree \(j\)).
  • \(\sigma^2\) models unexplained variance/noise.
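To make the data model concrete, here is a minimal simulation sketch of the sum-of-trees form. The “trees” are toy one-split stumps rather than draws from the actual tree prior, and all numbers (\(M\), \(\sigma^2\), the leaf scale) are illustrative assumptions.

    import numpy as np

    rng = np.random.default_rng(0)

    def stump(x, split_var, split_val, mu_left, mu_right):
        """A toy one-split tree: f(x, T_j, Theta_j) with two leaves."""
        return np.where(x[:, split_var] <= split_val, mu_left, mu_right)

    n, p, M = 200, 3, 10
    X = rng.uniform(size=(n, p))

    # Hypothetical forest: each leaf mean is small, roughly N(0, (0.5 / (2 sqrt(M)))^2)
    leaf_sd = 0.5 / (2 * np.sqrt(M))
    forest = [dict(split_var=rng.integers(p), split_val=rng.uniform(),
                   mu_left=rng.normal(0, leaf_sd), mu_right=rng.normal(0, leaf_sd))
              for _ in range(M)]

    signal = sum(stump(X, **t) for t in forest)     # E(y | x) = sum_j f(x, T_j, Theta_j)
    sigma2 = 0.1                                    # in the full model, sigma^2 ~ IG(nu/2, nu*lambda/2)
    y = signal + rng.normal(0, np.sqrt(sigma2), n)  # y_i = signal + noise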

So what are we doing (single tree version)?

  • Run \(N_{\text{MCMC}}\) iterations, each representing a draw from the probability distribution over \(\hat{f}\).

  • Uncertainty comes into play in a couple ways: for each iteration, trees will look a little different.

  • Each tree splits the data, and the predicted outcome in each partition is randomly drawn, centered at the mean of that partition.

  • The variance in the error term (whatever isn’t accounted for by the trees) is another random draw per iteration of the algorithm.

Where are the sum of trees introduced?

  • So far, we have mostly focused on how to build a Bayesian tree.

  • Borrowing ideas from boosting, we can sum trees by fitting each subsequent tree to the residual, \(r = y-f(\mathbf{x})\), of the previous trees.

Prior specifications (fairly technical)

A bit more on the residual variance prior

  • Recall, \(\varepsilon\sim N(0,\sigma^2)\). This is the “model” uncertainty, i.e. the uncertainty remaining beyond what we can explain with \(f(\mathbf{x})\). Mean zero implies we assume \(E(y\mid\mathbf{x})=f(\mathbf{x})\longrightarrow y\sim N(f(\mathbf{x}), \sigma^2)\).

  • Prior on the error variance \(\sigma^2\): calibrated so that there is a 90% chance that \(\sigma^2\) is less than \(\hat{\sigma}^2\), the estimate from a linear regression fit to the data.

  • Later, we will see how to model \(\sigma^2\) to be a function of \(\mathbf{x}\).

Likelihood & leaf prior specifications

  • The constant predicted outcome in each leaf of each tree is drawn from \(\mu_{qj}\sim N(0, \tau)\), which is conjugate with the normal outcome likelihood! Since the prior is centered at \(0\), shift \(y\) to lie between \(-0.5\) and \(0.5\).

  • So, \(\tau = \frac{0.5}{k\sqrt{m}}\), where \(m\) is the number of trees. Choosing \(k\) gives roughly a \(2\Phi(k)-1\) chance that the tree sums land in that range (e.g. \(k=2\) gives \(\approx 95\%\); a worked example follows this list). The set of leaf parameters for a tree \(j\) is described by the set \(\Theta_{j}\).

  • This has a shrinking effect; it helps ensure each tree contributes only a small piece of the fit. The conjugacy here will come in handy later.
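A worked calibration, using the illustrative choices \(k=2\) and \(m=200\) (both are just example values):

    from math import erf, sqrt

    k, m = 2, 200                     # illustrative choices of k and the number of trees
    tau = 0.5 / (k * sqrt(m))         # prior sd of each leaf mean mu_qj
    sum_sd = sqrt(m) * tau            # sd of the sum of m leaf means = 0.5 / k
    coverage = erf(k / sqrt(2))       # = 2 * Phi(k) - 1: prior mass of the sum on (-0.5, 0.5)
    print(round(tau, 4), round(sum_sd, 4), round(coverage, 3))   # 0.0177 0.25 0.954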

Tree prior

  • We also need a prior for growing the tree.

\[ \Pr(\text{node keeps growing at depth $d$}) = \frac{\alpha}{(1+d)^{\beta}} \]

Regularizes each tree to not split too much, i.e. a subtle push toward “additivity” in a sense.

  • Also include a prior on which variable to split on, and where along that variable’s range to split. Default to a uniform distribution over variables, and uniform over a discrete set of points in the range of the chosen variable (a small numerical illustration of the depth prior follows).
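For intuition, the sketch below evaluates the depth prior at increasing depths, assuming the commonly used values \(\alpha=0.95\) and \(\beta=2\):

    alpha, beta = 0.95, 2.0   # assumed (commonly used) tree prior hyperparameters

    def p_split(depth):
        """Prior probability that a node at the given depth keeps growing."""
        return alpha / (1 + depth) ** beta

    for d in range(5):
        print(d, p_split(d))
    # falls from 0.95 at the root to roughly 0.24, 0.11, 0.06, 0.04 at depths 1-4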

A cool note on the tree prior

  • The tree prior is a “dilution prior” (George 1999).

  • Dilutes probability within similar clusters of trees.

  • This plays itself out in the posterior. If two trees fit equally well, it will place more probability on the more “unique” tree.

  • Ex) two trees splitting on two very correlated variables will have \(\approx\) equal prior weight, lowering each of their probabilities.

One more thing

  • Choose priors such that the tree components are independent of each other and of \(\sigma^2\), the variance parameter in \(\varepsilon\).

  • That is:

    \[ \begin{align} \pi\left((T_1,\Theta_{1}), \ldots, (T_m,\Theta_m),\sigma\right)&=\left[\prod_{j=1}^{m}\pi(\Theta_j, T_j)\right]\pi(\sigma^2)\\ &=\left[\prod_{j=1}^{m}\pi(\Theta_j\mid T_j)\pi(T_j)\right]\pi(\sigma^2)\\ \pi(\Theta_j\mid T_j)&=\prod_{q=1}^{\text{# partitions in tree $j$}}\pi(\mu_{qj}\mid T_j) \end{align} \]

  • We have now specified the priors: \(\pi(T_j)\), \(\pi(\mu_{qj}\mid T_j)\), and \(\pi(\sigma^2)\).

Not always the case

  • For example, if we want to build “monotone” trees (Chipman et al. 2022), one way to do so is to ensure every leaf node’s value is greater/less than the previous one’s. Now the leaf nodes must be dependent!

  • Without independence, it is harder to sample the posterior. (Chipman et al. 2022) use a grid to evaluate the resulting posterior.

What the tree growth prior implies

Number of terminal nodes   Prior probability
1                          0.05
2                          0.55
3                          0.28
4                          0.09
\(\geq\) 5                 0.03

Posterior computations (more technical)

Integrating out the leaf parameters

  • Want to sample from the (single, for now) tree posterior, \(\Pr(T_{j}\mid y, \sigma^2) \propto f(y\mid T_{j}, \sigma^2)\pi(T_{j})\). Of course, we are actually modeling \(\Pr((T_{j},\Theta_{j})\mid y, \sigma^2)\), but because of the conjugacy of the prior and likelihood for the leaf parameters \(\mu_{qj}\in \Theta_{j}\) within a tree \(j\), we can “average” or integrate out \(\Theta_{j}\):

    \(f(y\mid T_{j}, \sigma^2) = \int f(y\mid \Theta_{j}, T_{j}, \sigma^2)\pi(\Theta_{j}\mid T_{j}, \sigma^2)\,\text{d}\Theta_{j}\),

which conveniently has a closed-form solution (thanks, conjugacy!). A sketch of this computation follows.
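A minimal sketch of that closed form (the standard normal-normal marginal likelihood), treating \(\tau\) as the prior variance of a leaf mean; for a whole tree, independence across leaves means the log marginal likelihood is the sum of this quantity over its leaves.

    import numpy as np

    def leaf_log_marginal(r, sigma2, tau):
        """log of the integral over mu of prod_i N(r_i; mu, sigma2) * N(mu; 0, tau).

        r: residuals falling in one leaf; tau: prior variance of the leaf mean.
        Closed form thanks to the normal-normal conjugacy.
        """
        r = np.asarray(r, dtype=float)
        n, s, ss = r.size, r.sum(), np.sum(r**2)
        return (-0.5 * n * np.log(2 * np.pi * sigma2)
                + 0.5 * np.log(sigma2 / (sigma2 + n * tau))
                - ss / (2 * sigma2)
                + tau * s**2 / (2 * sigma2 * (sigma2 + n * tau)))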

Posterior of tree space

  • We now have the ingredients to describe the form of the posterior of the trees given the observed data.

\[ \Pr(T\mid y)=\frac{\Pr(y\mid T)\pi(T)}{\int_{\text{tree space}}\Pr(y\mid T)\pi(T)\text{d}\text{(trees)}} \]

  • Cannot possibly evaluate every possible tree. MCMC devised to stochastically explore the tree space.

Random walk MCMC (grow/prune)

  • Accept/reject proposed trees based on the product \(f(y\mid T)\pi(T)\) and the transition probabilities between the new tree and the previous tree.

  • To approximate the correct distribution (eventual convergence) for \(\Pr(T\mid y)\), need an irreducible, aperiodic, and recurrent Markov chain.

  • So need to satisfy these properties and show empirical evidence of reasonably quick convergence.

  • Want reversible tree proposals to help guarantee the Markov chain is irreducible (not automatically true for an arbitrary Markov chain).

Random walk MCMC: Part II

  • The MCMC algorithm of (Chipman, George, and McCulloch 1998) is a concise algorithm and works pretty well in practice.

    1) Grow: Randomly pick a terminal node and split on a new variable at random, creating a new split rule and two new leaf nodes. (50% chance)

    2) Prune: Randomly choose a parent of two leaf nodes and remove its leaves, making that parent a terminal/leaf node. (50% chance)

How to evaluate

  • Evaluate the new proposed trees based on:

    \[ \alpha(T^*\mid T^k)=\min\left(\frac{Q(T^k\mid T^*)g(T^*)}{Q(T^*\mid T^k)g(T^k)},1\right) \]

  • Where the transition kernel \(Q(\cdot\mid\cdot)\) is defined by the moves on the previous slide. \(^*\) signifies the proposed new tree and the \(^k\) index represents the previous MCMC iteration. \(\frac{g(T^*)}{g(T^k)}=\frac{\Pr(T^*\mid y)}{\Pr(T^k\mid y)}=\frac{f(y\mid T^*)\pi(T^*)/f(y)}{f(y\mid T^k)\pi(T^k)/f(y)}\), and the marginal likelihood \(f(y)\) cancels out.

  • \(Q(T^k \mid T^*)\) is the probability of proposing tree \(T^k\) from the proposed state \(T^*\) (a death/prune move) and \(Q(T^* \mid T^k)\) is the probability of proposing tree \(T^*\) given the current tree structure \(T^k\) (a birth/grow move). A generic sketch of the accept/reject step follows.
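Below is a generic, library-agnostic sketch of this accept/reject step on the log scale; the log-likelihoods would come from the integrated leaf likelihood above, and the prior and transition terms from the ratios on the following slides (all arguments here are stand-ins).

    import numpy as np

    def mh_accept(log_lik_new, log_prior_new, log_q_reverse,
                  log_lik_old, log_prior_old, log_q_forward, rng=None):
        """Accept/reject a grow/prune proposal T* given the current tree T^k.

        log_q_forward = log Q(T* | T^k): probability of proposing T* from T^k.
        log_q_reverse = log Q(T^k | T*): probability of proposing T^k back from T*.
        g(T) = f(y | T) * pi(T) enters through the likelihood and prior terms.
        """
        rng = np.random.default_rng() if rng is None else rng
        log_alpha = ((log_lik_new + log_prior_new + log_q_reverse)
                     - (log_lik_old + log_prior_old + log_q_forward))
        return np.log(rng.uniform()) < min(log_alpha, 0.0)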

Transition ratio

  • For the “birth move”, the probability of splitting off of one of the current \(n_\text{bottom}\) nodes is:

    \[ Q(T^*\mid T^k) = \frac{1}{n_{\text{bottom}}}\Pr(\text{birth})\Pr(\text{selecting variable})\Pr(\text{splitting along variable}) \]

  • Where \(\Pr(\text{birth})=0.5\) (usually), and the other two probabilities default to uniform distributions.

A familiar foe has returned: the likelihood ratio

  • The ratio \(\frac{g(T^*)}{g(T^k)}=\frac{f(y\mid T^*)\pi(T^*)}{f(y\mid T^k)\pi(T^k)}\) plays a critical role. The prior for the tree was specified earlier, and the marginal likelihood \(f(y\mid T)\) is analytically available for both the “left” and “right” splits of a tree.

  • Since we assume independence across leaf nodes, the marginal likelihood of a tree is just the product over its leaves.

  • See (Sparapani, Spanbauer, and McCulloch 2021) for more detail, or appendix C.1.2 of (Tan and Roy 2019).

Tree prior ratio

  • Finally,

    \[ \frac{\pi(T^*)}{\pi(T^k)} = \frac{\left[1-\Pr(\text{tree grows left})\right]\cdot\left[1-\Pr(\text{tree grows right})\right]\cdot \Pr(\text{tree grows})\Pr(\text{which})\Pr(\text{where})}{1-\Pr(\text{tree grows})} \]

  • \(\Pr(\text{grows})\) corresponds to \(\Pr(\text{node keeps growing at depth $d$}) = \frac{\alpha}{(1+d)^{\beta}}\). “Which” and “where” are shorthand for \(\Pr(\text{selecting variable})\) and \(\Pr(\text{splitting along variable})\) respectively (a small numerical sketch follows).
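Putting these pieces together, a small sketch of the prior ratio for a birth at a node of depth \(d\), with the split-variable and cutpoint probabilities passed in (uniform by default):

    import math

    def log_tree_prior_ratio(depth, p_var, p_cut, alpha=0.95, beta=2.0):
        """log pi(T*) - log pi(T^k) for a birth (grow) move at a node of the given depth.

        p_var, p_cut: probabilities of the chosen split variable and cutpoint;
        alpha, beta: tree prior hyperparameters (defaults are illustrative).
        """
        p_grow = alpha / (1 + depth) ** beta        # the node at depth d splits...
        p_grow_child = alpha / (2 + depth) ** beta  # ...and its two new children (depth d + 1) do not
        return (2 * math.log(1 - p_grow_child) + math.log(p_grow)
                + math.log(p_var) + math.log(p_cut) - math.log(1 - p_grow))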

Bayesian backfitting

  • A single tree may take a really long time to mix. Can get stuck at a large tree.

  • Fitting many trees helps mix better. Since each tree contributes a small part to the fit, with enough trees can find a reasonable \(f(\mathbf{x}_i)\) sooner.

Full MCMC procedure

  • Sample a tree structure with the Metropolis-Hastings birth/death move, with the residual acting as the outcome (the first tree just uses the outcome).

  • Sample leaf node parameters via Gibbs. The posterior is also normal thanks to the conjugacy, and the residual still plays the role of the outcome.

  • Calculate the residual from this tree, and move on to building the next tree with this residual.

  • After going through all the trees, sample \(\sigma^2\) from its Inverse Gamma posterior (the prior \(\text{IG}\left(\frac{\nu}{2},\frac{\nu\cdot \lambda}{2}\right)\) updated by the residuals; a sketch of this conjugate draw follows). Move on to the next MCMC iteration of the sampler.
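A minimal sketch of that conjugate \(\sigma^2\) draw, given the residuals after subtracting all tree fits; the prior \(\text{IG}(\nu/2, \nu\lambda/2)\) updates to \(\text{IG}\left(\frac{\nu+n}{2}, \frac{\nu\lambda+\sum_i r_i^2}{2}\right)\).

    import numpy as np

    def draw_sigma2(resid, nu, lam, rng=None):
        """Gibbs draw of sigma^2 given full-forest residuals resid_i ~ N(0, sigma^2).

        Prior: sigma^2 ~ IG(nu/2, nu*lam/2);
        posterior: IG((nu + n)/2, (nu*lam + sum_i resid_i^2)/2).
        """
        rng = np.random.default_rng() if rng is None else rng
        shape = 0.5 * (nu + len(resid))
        scale = 0.5 * (nu * lam + float(np.sum(np.square(resid))))
        return scale / rng.gamma(shape)   # 1 / Gamma(shape, rate=scale) is InvGamma(shape, scale)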

Full MCMC procedure: more formally

For trees \(j\in 1, 2, \ldots, M\), with \(-j\) indicating conditioning on the other trees, let \(\mathcal{T}\) be the set of all the trees and \(\mathbf{\Theta}\) the collection of all the \(\Theta_{j}\)’s (the leaf parameters of each tree \(j\)):

  1. Sample \(T_j, \mu_{j}\mid \mathcal{T}_{-j}, \mathbf{\Theta}_{-j}, \sigma^2, y\).

    Because of the conjugacy of the outcome with the priors on the leaf mean parameters \(\mu_{qj}\), this can be done as:

    \(T_j\mid \mathcal{T}_{-j}, \mathbf{\Theta}_{-j}, \sigma^2, y\) and \(\mu_{j}\mid \mathcal{T}, \Theta_{-j}, \sigma^2, y\).

  2. Sample \(\sigma^2\mid \mathcal{T}, \Theta, y\).

An alternative and simpler sampler

Since \(\Pr(T_j, \mu_{j}\mid \mathcal{T}_{-j}, \mathbf{\Theta}_{-j}, \sigma^2, y)\) depends on \((\mathcal{T}_{-j}, \mathbf{\Theta}_{-j}, y)\) only through the residual, we can write these steps conditional on the current residuals \(r_j\):

  1. Sample \(T_j, \mu_{qj}\mid r_{j}, \sigma^2\) as \(T_j \mid r_j, \sigma^2\) (drawn via the birth/death MCMC move) and \(\mu_{qj}\mid T_{j}, r_{j}, \sigma^2\) (whose posterior is normal; see the sketch below). Once all the \(\mu_{qj}\in \Theta_{j}\) are drawn, update the residuals and move on to the next tree.

  2. After the \(m\) trees are sampled, sample \(\sigma^2\mid r\) , which is an Inverse Gamma posterior.

Repeat this procedure for \(N_{\text{MCMC}}\) draws. Usually want to discard some initial draws as “burn-in” until the sampler finds high-probability trees… hard to know in practice when that has happened.
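For completeness, a minimal sketch of the conjugate leaf-mean draw \(\mu_{qj}\mid T_j, r_j, \sigma^2\) referenced in step 1, again treating \(\tau\) as the prior variance of the leaf mean:

    import numpy as np

    def draw_leaf_mean(r_leaf, sigma2, tau, rng=None):
        """Conjugate normal draw of a leaf mean given the residuals falling in that leaf."""
        rng = np.random.default_rng() if rng is None else rng
        n, s = len(r_leaf), float(np.sum(r_leaf))
        post_var = sigma2 * tau / (sigma2 + n * tau)   # 1 / (n / sigma2 + 1 / tau)
        post_mean = tau * s / (sigma2 + n * tau)       # shrinks the leaf average toward 0
        return rng.normal(post_mean, np.sqrt(post_var))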

The TL;DR

  • Posterior for \(y\mid\mathbf{x}\) is proportional to \(\text{likelihood}\times \text{prior}\).

  • Prior to seeing data, expect small contributions from each tree.

  • Likelihood of tree describes how well proposed tree “fits” the data.

  • So… trees can grow deeper if data permit.

Mixing problems

  • The BART MCMC algorithm proposes one move per tree at a time and can take a while to mix.

  • Worryingly, can sometimes get stuck in poor fitting regions of the tree posterior space.

  • Ideally, the draws are random samples from the posterior and should not show any autocorrelation draw to draw, but BART often does not mix well enough.

  • (Ronen et al. 2022), albeit for a very simplified version of BART, show that mixing degrades as \(n\) grows.

How to improve mixing

Motivations

  • Ideally, starting the sampler at high-probability trees and/or averaging over many independent chains of forests should help.

  • With many redundant features, it is more likely for BART to be doomed by a poor first split in a tree.

  • Therefore, we’d want to help by providing variables with strong signal to split on (more to come).

XBART: Build the trees anew every iteration!

  • (He and Hahn 2023) combines recursive partitioning (repeatedly calling the same splitting process) from CART (Breiman et al. 1984) with stochastic growing rules instead of greedy optimization.

  • From Bayes’ rule, we can write (where \(s\) represents where to split and \(v\) represents which variable to split on):

    \[\Pr(s,v) = \frac{f(s,v)\pi(s)}{\sum_{s_j}\sum_{v_j}f(s_j,v_j)\pi(s_j)}\]

  • \(f(s,v)\): the marginal likelihood of the implied tree split, available in analytical form. “No split” corresponds to the tree prior’s “done splitting” probability (a sketch of this sampling step follows).
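A rough sketch of this stochastic split selection for one candidate variable, reusing the leaf_log_marginal function from the earlier sketch; the log-prior weights for “split” and “no split” are passed in as stand-ins for the tree prior terms, so this illustrates the idea rather than the exact XBART implementation.

    import numpy as np

    def split_probabilities(x, r, sigma2, tau, log_prior_split, log_prior_stop):
        """Categorical weights over the cutpoints of one feature x (plus "no split").

        Each candidate cutpoint is scored by the integrated likelihood of its left and
        right leaves plus a log prior weight; probabilities come from normalizing the
        exponentiated scores. Assumes leaf_log_marginal (earlier sketch) is in scope.
        """
        order = np.argsort(x)
        x_sorted, r_sorted = np.asarray(x)[order], np.asarray(r)[order]
        cutpoints = np.unique(x_sorted)[:-1]          # candidate split values
        scores = [leaf_log_marginal(r_sorted, sigma2, tau) + log_prior_stop]  # "no split"
        for c in cutpoints:
            left, right = r_sorted[x_sorted <= c], r_sorted[x_sorted > c]
            scores.append(leaf_log_marginal(left, sigma2, tau)
                          + leaf_log_marginal(right, sigma2, tau) + log_prior_split)
        scores = np.array(scores)
        probs = np.exp(scores - scores.max())
        return probs / probs.sum()                    # first entry = Pr(no split)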

XBART: Part II

  • Build each tree in the forest fully (via recursive stochastic splits). Still pass residuals to the next tree.

  • Restart entire forest (remake every tree) for each iteration.

  • Pass the variance parameter at the end of every iteration.

But wait… there’s more

  • Other computational tricks incorporated to make tree building faster.

  • Each XBART iteration is slower than a BART MCMC one-move-at-a-time forest sweep. BUT… mixing is considerably faster. BART usually needs thousands of MCMC draws, XBART merely 5-20 iterations!

  • Amounts to 20-30x speedup and capability to handle larger datasets.

  • However, XBART is not a valid BART posterior (trees aren’t built reversibly).

A clever fix

  • Warm-start BART: initialize the MCMC sampler’s forests with trees from the XBART forests, then do ~100 MCMC draws.
  • Can either start the sampler from the last XBART iteration or run \(m\) chains with \(m\) different XBART forests.
  • If the well-initialized BART MCMC samplers are mixing well, each chain represents an independent sequence of draws from the posterior.
  • Therefore, the combined chains still give the BART posterior.

Advantages

  • Base XBART undercovers; warm-start outperforms both BART and XBART. Easily parallelizable, so we can run multiple chains on modern processors to improve prediction and coverage!
  • Warmstart improves BART coverage and facilitates faster mixing… best of both worlds!
  • Trees don’t need to be built “Bayesianly”; they can act as a “data-informed prior”. XBART does benefit from using the marginal likelihood and sampling the other parameters, though.

Some other XBART changes

  • The variance \(\tau\) of the prior on the values drawn in the leaves now has its own prior (a hyperprior)!

    \[ \mu_{qj}\sim N(0,\tau) \; ; \; \tau\sim \text{IG}(a,b) \]

  • Rather than choosing a split variable uniformly at random (as in BART), or considering all variables as in the Bayes-rule approach above, update variable weights following (Linero 2018): a Dirichlet prior on which variable to split on, with weights updated between sweeps.

Another potential improvement: main-effect demeaning

  • Append a “main effect estimate”, \(\hat{y}_{\text{reference}}\), as a column in BART.

  • Default to the OLS estimate \(\hat{y}=\mathbf{x}\hat{\beta}\). Works pretty well! (A minimal sketch follows this list.)

  • Can include any \(\hat{y}\) (or multiple of them) you want depending on what you know about the DGP. Could be different BART priors, regularized linear regression estimates, or specific basis regressions.

  • If \(p\) is large, put higher prior weight on splitting on the main effect estimate column.
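A minimal sketch of this augmentation with a toy data-generating process; the linear “reference” fit and all numbers are made up for illustration, and X_augmented is what would be handed to the BART fit.

    import numpy as np

    rng = np.random.default_rng(1)
    n, p = 500, 10
    X = rng.normal(size=(n, p))
    y = X[:, 0] + 0.5 * X[:, 1] ** 2 + rng.normal(scale=0.5, size=n)   # toy DGP

    # OLS "main effect" reference fit: y_hat_reference = [1, X] beta_hat
    X1 = np.column_stack([np.ones(n), X])
    beta_hat, *_ = np.linalg.lstsq(X1, y, rcond=None)
    y_hat_reference = X1 @ beta_hat

    # Append the reference prediction as an extra column for the forest to split on
    X_augmented = np.column_stack([X, y_hat_reference])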

Testing on unseen data

  • Because we have the trees built, can apply the learnt rules to unseen \(\mathbf{X}\).
  • In areas where there is no training data, BART, being a tree based method, predicts the same constant value for all points falling in the same leaf node… arguably more of a feature than a bug in some applications.
  • Every testing point falls into some leaf node.

Extrapolation with Gaussian processes

  • (Wang, He, and Hahn 2024) replace BART predictions (after forests are built) with Gaussian process predictions. Only for points deemed to be in extrapolation zone.

    Visual illustration of GP extrapolation BART for a single iteration of a single tree in a forest

Don’t miss the forest for the trees

  • Individual trees are not identified!

  • Estimating \(E(y \mid \mathbf{x})\) is the goal, and the sum of trees (the “forest”) in aggregate accomplishes that goal.

What are the secret ingredients?

  1. Boosting is smart: A single decision tree can fit data really well if it is very large. However, it is prone to over-fitting and not sustaining that performance on unseen data, limiting its usefulness. Many small trees, each fitting a portion of the data (in BART, sequentially through the “leftovers” of the previous tree), is a very smart idea! Many trees can yield similar likelihoods, and with a sum of trees this is even more pronounced, so individual trees are certainly not identified… but this is a feature, not a bug.

    The sum-of-trees model, with its abundance of unidentified parameters, allows for “fit” to be freely reallocated from one tree to another. Because each move makes only small incremental changes to the fit, we can imagine the algorithm as analogous to sculpting a complex figure by adding and subtracting small dabs of clay.

    The original BART paper (Chipman, George, and McCulloch 2012)

  2. Well calibrated regularization priors are even smarter: BART describes a model-based approach to balance keeping trees (and their leaf contributions) small through regularization priors, while “listening” to the data (the likelihood) if larger trees are necessary. This is the mechanism behind the clay analogy in the quote above.

  3. Sampling different forests is maybe the smartest: Instead of finding an “optimal” single forest (like xgboost), BART tries out a bunch of configurations. BART’s samples from the posterior of \(f(\mathbf{x}_i)\) also yield uncertainty quantification. The BART MCMC usually works remarkably well. Averaging many forests when taking the posterior mean has an extra smoothing effect on point estimates.

  • But… it can take a while to converge (despite eventually getting there). XBART provides a way to approximate the BART posterior. Warm-start BART, initializing BART with the XBART-grown trees, starts the MCMC algorithm in a higher-probability region of tree space.

  • Bonus: modeling the error variance \(\sigma^2\) explicitly seems to help BART fare better in noisier settings. Opens doors to different modeling regimes and better uncertainty quantification.

Visual

Examples

Tongue in cheek BART

  • A fun representation of the other BART.

Takeaways

  • BART is an excellent performer and has natural uncertainty quantification.

  • Bespoke BART models are easier to develop than other machine learning tools.

  • BART usually works pretty well off the bat in a variety of modeling landscapes.

References

Breiman, Leo, Jerome Friedman, Richard Olshen, and Charles Stone. 1984. Classification and Regression Trees.
Chipman, Hugh A, Edward I George, and Robert E McCulloch. 2012. “BART: Bayesian Additive Regression Trees.” Annals of Applied Statistics 6 (1): 266–98.
Chipman, Hugh A, Edward I George, and Robert E McCulloch. 1998. “Bayesian CART Model Search.” Journal of the American Statistical Association 93 (443): 935–48.
Chipman, Hugh A, Edward I George, Robert E McCulloch, and Thomas S Shively. 2022. “mBART: Multidimensional Monotone BART.” Bayesian Analysis 17 (2): 515–44.
George, E. 1999. “Discussion of ‘Model Averaging and Model Search Strategies’ by M. Clyde.” In Bayesian Statistics 6: Proceedings of the Sixth Valencia International Meeting.
He, Jingyu, and P Richard Hahn. 2023. “Stochastic Tree Ensembles for Regularized Nonlinear Regression.” Journal of the American Statistical Association 118 (541): 551–70.
Linero, Antonio R. 2018. “Bayesian Regression Trees for High-Dimensional Prediction and Variable Selection.” Journal of the American Statistical Association 113 (522): 626–36.
———. 2024. “Generalized Bayesian Additive Regression Trees Models: Beyond Conditional Conjugacy.” Journal of the American Statistical Association, 1–14.
Ronen, Omer, Theo Saarinen, Yan Shuo Tan, James Duncan, and Bin Yu. 2022. “A Mixing Time Lower Bound for a Simplified Version of BART.” arXiv Preprint arXiv:2210.09352.
Sparapani, Rodney, Charles Spanbauer, and Robert McCulloch. 2021. “Nonparametric Machine Learning and Efficient Computation with Bayesian Additive Regression Trees: The BART R Package.” Journal of Statistical Software 97: 1–66.
Tan, Yaoyuan Vincent, and Jason Roy. 2019. “Bayesian Additive Regression Trees and the General BART Model.” Statistics in Medicine 38 (25): 5048–69.
Wang, Meijia, Jingyu He, and P Richard Hahn. 2024. “Local Gaussian Process Extrapolation for BART Models with Applications to Causal Inference.” Journal of Computational and Graphical Statistics 33 (2): 724–35.