January 2025
Theme borrowed from Emi Tanaka
Also used elements from this excellent quarto tutorial
Discussions with Andrew Herren, Rafael Alcantara, Rob McCulloch, Hedibert Lopes, and Richard Hahn provided excellent insights for these notes.
BART (Hugh A. Chipman, George, and McCulloch 2012) is a prior over functions. Each function is a sum of decision trees, a.k.a. a forest.
BART is a Gaussian process, with a kernel learned from the data.
Anything random forest (or XGBoost) can do, BART can (usually) do better, particularly learning complex functions in noisy situations.
BART’s Bayesian framework invites custom models. stochtree enables such modifications without having to edit C++.
Implements BART and other stochastic tree variants. Fast, easy to use, and easy to customize.
Implements many bells and whistles.
Decision trees can be represented mathematically:
\[ f(\mathbf{x})=\sum_{j=1}^{\text{# partitions}} \mu_{j}\cdot\mathbf{1}\left(\mathbf{x}\in \mathcal{S}_j\right) \]
This says that the space of \(\mathbf{x}\) is partitioned into sections \(\mathcal{S}_j\), and the mean of the section that \(\mathbf{x}\) falls into is the predicted value.
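A minimal NumPy sketch of this piecewise-constant view (the cutoffs and leaf means below are made up for illustration):

```python
import numpy as np

# Hypothetical single tree on one feature: three partitions of the real line,
# each with its own leaf mean mu_j. The splits at 2 and 5 are made up.
cutoffs = np.array([2.0, 5.0])           # S_1 = (-inf, 2), S_2 = [2, 5), S_3 = [5, inf)
leaf_means = np.array([0.3, -1.1, 0.8])  # mu_1, mu_2, mu_3

def tree_predict(x):
    """f(x) = sum_j mu_j * 1(x in S_j): look up which partition x falls in."""
    j = np.searchsorted(cutoffs, x, side="right")  # index of the partition containing x
    return leaf_means[j]

print(tree_predict(np.array([1.0, 3.5, 7.2])))  # -> [ 0.3 -1.1  0.8]
```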
A sample of a decision tree
I want to predict how well I will sleep.
I expect some interactions between my daily habits and my next night's sleep. If it's warm out and I slept well the night before, I tend to sleep well the next night. This alone may not be enough…
However, I don't want to account for too many if-else statements, or else we'll overfit the idiosyncrasies of my day.
BART
\[ \begin{align} y_i &= \sum_{j=1}^{M}f(\mathbf{x}_i, T_j, \Theta_j)+\varepsilon_i, \qquad \varepsilon_i\sim N(0, \sigma^2)\\ E(y\mid \mathbf{x}) &= \text{signal} =\sum_{j=1}^{M}f(\mathbf{x}, T_j, \Theta_j)\\ \text{noise: } \sigma^2 &\sim \text{IG}\left(\frac{\nu}{2},\frac{\nu\lambda}{2}\right) \end{align} \]
1 portion done, two more to go
Run \(N_{\text{MCMC}}\) iterations, representing draws from the probability distribution of \(\hat{f}\).
Uncertainty comes into play in a couple of ways: for each iteration, the trees will look a little different.
Each tree splits the data, and the predicted outcome in each leaf is randomly drawn, centered at the mean of the data in that partition.
The variance in the error term (whatever isn’t accounted for by the trees) is another random draw per iteration of the algorithm.
So far, we have mostly focused on how to build a Bayesian tree.
Borrowing ideas from boosting, we can sum trees by fitting each subsequent tree to the residual, \(r = y-f(\mathbf{x})\), of the previous trees.
Recall, \(\varepsilon\sim N(0,\sigma^2)\). This is the “model” uncertainty, i.e. the remaining uncertainty beyond what we can explain with \(f(\mathbf{x})\). Mean zero implies we assume \(E(y\mid\mathbf{x})=f(\mathbf{x})\longrightarrow y\sim N(f(\mathbf{x}), \sigma^2)\).
Prior on error term variance \(\sigma^2\): calibrated such that 90% chance that \(\sigma^2\) will be less than \(\hat{\sigma}^2\) estimated by a linear regression fit on the data.
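A sketch of that calibration using the inverse gamma quantile function; \(\nu = 3\), the 90% target, and the \(\hat{\sigma}^2\) value below are illustrative assumptions:

```python
import numpy as np
from scipy.stats import invgamma

nu, q = 3.0, 0.90      # assumed degrees of freedom and target quantile
sigma2_hat = 1.7       # e.g., residual variance from an OLS fit (made-up value)

# Choose lambda so that P(sigma^2 < sigma2_hat) = q under IG(nu/2, nu*lambda/2).
# The scale of an inverse gamma multiplies its quantiles, so we can solve directly.
lam = 2.0 * sigma2_hat / (nu * invgamma.ppf(q, a=nu / 2.0))

# Check the calibration: should print roughly 0.90.
print(invgamma.cdf(sigma2_hat, a=nu / 2.0, scale=nu * lam / 2.0))
```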
Later, we will see how to model \(\sigma^2\) to be a function of \(\mathbf{x}\).
The constant predicted outcome in each leaf of each tree is drawn from \(\mu_{qj}\sim N(0, \tau)\), which is conjugate with the normal outcome likelihood! Because the prior is centered at \(0\), shift \(y\) to lie between \(-0.5\) and \(0.5\).
So, \(\tau = \frac{0.5}{k\sqrt{m}}\), where \(m\) is the number of trees. Choose \(k\) such that the sum of the trees falls in that range with prior probability \(\Phi(k)-\Phi(-k)\) (roughly 95% for \(k=2\)). The set of leaf parameters for a tree \(j\) is described by the set \(\Theta_{j}\).
Has a shrinking effect; helps ensure small contributions from components of the tree. The conjugacy here will come in handy later.
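A quick numerical check of this calibration, treating \(\tau = \frac{0.5}{k\sqrt{m}}\) as the standard deviation of each leaf draw and using illustrative values \(k=2\), \(m=200\):

```python
import numpy as np

rng = np.random.default_rng(0)
m, k = 200, 2                      # number of trees and calibration constant (illustrative)
tau = 0.5 / (k * np.sqrt(m))       # leaf-prior scale from the formula above

# Draw one leaf value per tree and sum them; repeat many times.
sums = rng.normal(0.0, tau, size=(100_000, m)).sum(axis=1)

# Fraction of prior draws of the tree sum landing in (-0.5, 0.5);
# should be roughly Phi(k) - Phi(-k), about 0.95 for k = 2.
print(np.mean(np.abs(sums) < 0.5))
```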
The GOAT wikipedia page on conjugate priors.
\[ \Pr(\text{node keeps growing at depth $d$}) = \frac{\alpha}{(1+d)^{\beta}} \]
Regularizes each tree to not split too much, i.e. a subtle push to “additivity” in a sense.
(Linero 2018), linked here, uses a Dirichlet variable selection prior instead. You could also provide your own variable-splitting weights in stochtree.
The tree prior is a “dilution prior” (George 1999).
Dilutes probability within similar clusters of trees.
This plays out in the posterior: if two trees fit equally well, the posterior will place more probability on the more “unique” tree.
Ex) two trees splitting on two very correlated variables will have \(\approx\) equal prior weight, lowering each of their probabilities.
Choose priors such that the tree components are independent of each other and of \(\sigma^2\), the variance parameter in \(\varepsilon\).
That is:
\[ \begin{align} \pi\left((T_1,\Theta_{1}), \ldots, (T_m,\Theta_m),\sigma\right)&=\left[\prod_{j=1}^{m}\pi(\Theta_j, T_j)\right]\pi(\sigma^2)\\ &=\left[\prod_{j=1}^{m}\pi(\Theta_j\mid T_j)\pi(T_j)\right]\pi(\sigma^2)\\ \pi(\Theta_j\mid T_j)&=\prod_{q}^{\text{# partitions in tree $j$}}\pi(\mu_{qj}\mid T_j) \end{align} \]
We have now specified the priors: \(\pi(T_j)\), \(\pi(\mu_{qj}\mid T_j)\), and \(\pi(\sigma^2)\).
For example, if we want to build “monotone” trees (Hugh A. Chipman et al. 2022), one way to do so is to ensure every leaf node's value is greater/less than the previous one's. Now the leaf nodes must be dependent!
Without independence, harder to sample the posterior. (Hugh A. Chipman et al. 2022) use a grid to evaluate the resulting posterior.
| Number of terminal nodes | Prior probability |
|---|---|
| 1 | 0.05 |
| 2 | 0.55 |
| 3 | 0.28 |
| 4 | 0.09 |
| \(\geq\) 5 | 0.03 |
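For intuition, this table can be approximately reproduced by simulating trees from the branching-process prior above. The values \(\alpha = 0.95\) and \(\beta = 2\) below are the commonly cited defaults and are an assumption here, not stated above:

```python
import numpy as np
from collections import Counter

rng = np.random.default_rng(1)
alpha, beta = 0.95, 2.0   # assumed defaults for the depth penalty

def n_terminal_nodes():
    """Grow one tree from the prior; return its number of terminal (leaf) nodes."""
    frontier = [0]          # depths of nodes not yet decided; the root has depth 0
    leaves = 0
    while frontier:
        d = frontier.pop()
        if rng.random() < alpha / (1 + d) ** beta:   # node keeps growing
            frontier += [d + 1, d + 1]               # two children one level deeper
        else:
            leaves += 1
    return leaves

counts = Counter(min(n_terminal_nodes(), 5) for _ in range(50_000))
for n in range(1, 6):
    print(n, round(counts[n] / 50_000, 3))   # roughly 0.05, 0.55, 0.28, 0.09, 0.03
```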
Want to sample from (a single, for now) tree posterior, \(\Pr(T_{j}\mid y, \sigma^2) \propto f(y\mid T_{j}, \sigma^2)\pi(T_{j})\). Of course, we are actually modeling \(\Pr((T_{j},\Theta_{j})\mid y, \sigma^2)\), but because of conjugacy of prior and likelihood on the leaf parameters \(\mu_{qj}\in \Theta_{j}\) within a tree \(j\), we can “average” or integrate out \(\Theta_{j}\):
\(f(y\mid T_{j}, \sigma^2) = \int f(y\mid \Theta_{j}, T_{j}, \sigma^2)\pi(\Theta_{j}\mid T_{j}, \sigma^2)\text{d}\Theta_{j}\),
which conveniently has closed form solution (thanks conjugacy!).1
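As an illustration of that closed form, for a single leaf \(q\) containing the \(n_q\) (partial residual) observations \(r_1,\ldots,r_{n_q}\), the standard normal-normal integral gives (a sketch, treating \(\tau\) as the leaf-prior variance):

\[
\int \prod_{i=1}^{n_q} N(r_i\mid \mu_{qj},\sigma^2)\, N(\mu_{qj}\mid 0,\tau)\,\text{d}\mu_{qj}
= \left(2\pi\sigma^2\right)^{-n_q/2}\sqrt{\frac{\sigma^2}{\sigma^2+n_q\tau}}\,
\exp\left\{-\frac{1}{2\sigma^2}\left[\sum_{i=1}^{n_q} r_i^2-\frac{\tau\left(\sum_{i=1}^{n_q} r_i\right)^2}{\sigma^2+n_q\tau}\right]\right\}
\]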
\[ \Pr(T\mid y)=\frac{\Pr(y\mid T)\pi(T)}{\int_{\text{tree space}}\Pr(y\mid T)\pi(T)\text{d}\text{(trees)}} \]
Accept/reject proposed trees based on the ratio of \(f(y\mid T)\pi(T)\) for the proposed versus the current tree, together with the transition probabilities between the two.
To approximate the correct distribution (eventual convergence) for \(\Pr(T\mid y)\), need an irreducible, aperiodic, and recurrent Markov chain.
So need to satisfy these properties and show empirical evidence of reasonably quick convergence.
Want reversible tree moves to guarantee the Markov chain is irreducible (not automatically true for an arbitrary Markov chain).
The MCMC algorithm of (Hugh A. Chipman, George, and McCulloch 1998) is a concise algorithm that works pretty well in practice.
1) Grow: Randomly pick a terminal node and split on a new variable at random, creating a new split rule and two new leaf nodes. (50% chance)
2) Prune: Randomly choose a parent of two leaf nodes and remove its leaves, making that parent a terminal/leaf node. (50% chance)
Evaluate the new proposed trees based on:
\[ \alpha(T^*\mid T^k)=\min\left(\frac{Q(T^k\mid T^*)g(T^*)}{Q(T^*\mid T^k)g(T^k)},1\right) \]
Where the transition kernel \(Q(\cdot\mid\cdot)\) is defined by the moves on the previous slide. \(^*\) signifies the proposed new tree and the \(^k\) index represents the previous MCMC iteration. \(\frac{g(T^*)}{g(T^k)}=\frac{\Pr(T^*\mid y)}{\Pr(T^k\mid y)}=\frac{f(y\mid T^*)\pi(T^*)}{f(y\mid T^k)\pi(T^k)}\), since the marginal likelihood \(f(y)\) cancels out.
\(Q(T^k \mid T^*)\) is probability of proposing tree \(T^k\) given current state (death) and \(Q(T^* \mid T^k)\) is the probability of proposing tree \(T^*\) given the current tree structure \(T^k\) (a birth).
For the “birth move”, the probability of proposing a split off one of the current \(n_\text{bottom}\) terminal nodes is:
\[ Q(T^*\mid T^k) = \frac{1}{n_{\text{bottom}}}\Pr(\text{birth})\Pr(\text{selecting variable})\Pr(\text{splitting along variable}) \]
Where \(\Pr(\text{birth})=0.5\) (usually), and the other two probabilities are usually uniform.
The ratio \(\frac{g(T^*)}{g(T^k)}=\frac{f(y\mid T^*)\pi(T^*)}{f(y\mid T^k)\pi(T^k)}\) plays a critical role. The prior for the tree was specified earlier, and the marginal likelihood \(f(y\mid T)\) is analytically available for both the “left” and “right” splits of a tree.
Since we assume independence across leaf nodes, the marginal likelihood is just the product over leaves1.
See (Sparapani, Spanbauer, and McCulloch 2021) for more detail, or appendix C.1.2 of (Tan and Roy 2019).
Finally,
\[ \frac{\pi(T^*)}{\pi(T^k)} = \frac{\left[1-\Pr(\text{tree grows left})\right]\cdot\left[1-\Pr(\text{tree grows right})\right]\cdot \Pr(\text{tree grows})\Pr(\text{which})\Pr(\text{where})}{1-\Pr(\text{tree grows})} \]
\(\Pr(\text{grows})\) corresponds to \(\Pr(\text{node keeps growing at depth $d$}) = \frac{\alpha}{(1+d)^{\beta}}\). “Which” and “where” are shorthand for \(\Pr(\text{selecting variable})\) and \(\Pr(\text{splitting along variable})\), respectively.
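Putting the pieces together, here is a schematic sketch of the accept/reject computation on the log scale; the function and argument names are hypothetical placeholders, not stochtree's API:

```python
import numpy as np

def mh_accept(log_ml_new, log_ml_old,        # log f(y | T*, sigma^2), log f(y | T^k, sigma^2)
              log_prior_new, log_prior_old,  # log pi(T*), log pi(T^k)
              log_q_reverse, log_q_forward,  # log Q(T^k | T*), log Q(T* | T^k)
              rng=np.random.default_rng()):
    """Metropolis-Hastings acceptance for a proposed tree T* given current tree T^k."""
    log_alpha = (log_ml_new - log_ml_old) \
              + (log_prior_new - log_prior_old) \
              + (log_q_reverse - log_q_forward)
    return np.log(rng.random()) < min(log_alpha, 0.0)   # True -> keep T*, else keep T^k
```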
A single tree may take a really long time to mix. Can get stuck at a large tree.
Fitting many trees helps mix better. Since each tree contributes a small part to the fit, with enough trees can find a reasonable \(f(\mathbf{x}_i)\) sooner.
Sample a tree with the Metropolis-Hastings birth/death move, with the residual acting as the outcome (the first tree just uses the outcome).
Sample leaf node parameters via Gibbs. The posterior is also normal thanks to the conjugacy, and the residual still plays the role of the outcome.
Calculate the residual from this tree, then move on to building the next tree with this residual.
After going through all the trees, sample \(\sigma^2\) from its Inverse Gamma posterior (its prior is \(\text{Inverse Gamma}\left(\frac{\nu}{2},\frac{\nu\cdot \lambda}{2}\right)\)). Move on to the next MCMC iteration of the sampler.
For trees \(j = 1, 2, \ldots, M\), with \(-j\) indicating conditioning on the other trees, \(\mathcal{T}\) the set of all the trees, and \(\mathbf{\Theta}\) the collection of all the \(\Theta_{j}\)'s (the leaf parameters of each tree \(j\)):
Sample \(T_j, \mu_{j}\mid \mathcal{T}_{-j}, \mathbf{\Theta}_{-j}, \sigma^2, y\).
Which because of the conjugacy of the outcome and the priors on the leaf mean parameters \(\mu_{qj}\) can be done as:
\(T_j\mid \mathcal{T}_{-j}, \mathbf{\Theta}_{-j}, \sigma^2, y\) and \(\mu_{j}\mid \mathcal{T}, \Theta_{-j}, \sigma^2, y\).
Sample \(\sigma^2\mid \mathcal{T}, \Theta, y\).
Since \(\Pr(T_j, \mu_{j}\mid \mathcal{T}_{-j}, \mathcal{\Theta}_{-j}, \sigma^2, y)\) depends on \((T_{-j}, \Theta_{-j}, y)\) only through the residual, we can write these steps conditional on the current residuals \(r_j\):
Sample \(T_j, \mu_{qj}\mid r_{j}, \sigma^2, y\) as \(T_j \mid r_j, \sigma^2\) (which is drawn from the birth/death MCMC process) and \(\mu_{qj}\mid T_{j}, r_{j}\) (whose posterior is a draw from a normal distribution). Once all the \(\mu_{qj}\in \Theta_{j}\) are calculated, we can update the residuals and then the next tree can be calculated.
After the \(m\) trees are sampled, sample \(\sigma^2\mid r\) , which is an Inverse Gamma posterior.
Repeat this procedure for \(N_{\text{MCMC}}\) draws. Usually want to “burn in” some draws until the sampler finds high-probability trees… hard to know in practice when this is the case.
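Putting the sweep together, here is a schematic NumPy sketch of one pass over the forest. All names and hyperparameter values are illustrative assumptions (this is not stochtree's API), and `propose_tree_structure` is a stand-in for the birth/death MH step described earlier, left as a no-op so the sketch stays short and runnable:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup: each "tree" is summarized only by which leaf each observation falls in
# (leaf_ids) and its contribution to the fit (fits). Hyperparameters are illustrative.
n, M = 500, 20
X = rng.uniform(size=(n, 2))
y = np.sin(4 * X[:, 0]) + rng.normal(0, 0.3, n)

tau = (0.5 / (2 * np.sqrt(M))) ** 2   # leaf-prior variance: square of 0.5/(k*sqrt(m)), k = 2
nu, lam = 3.0, 0.1                    # assumed IG hyperparameters
sigma2 = np.var(y)

leaf_ids = [np.zeros(n, dtype=int) for _ in range(M)]   # every tree starts as a single root leaf
fits = np.zeros((M, n))                                  # per-tree contributions to f(x)

def propose_tree_structure(resid, ids, s2):
    """Stand-in for the birth/death Metropolis-Hastings step; here it leaves
    the tree structure unchanged so the sketch stays short and runnable."""
    return ids

for it in range(100):                                    # each pass below is one MCMC sweep
    for j in range(M):
        r = y - fits.sum(axis=0) + fits[j]               # partial residual: remove all other trees
        leaf_ids[j] = propose_tree_structure(r, leaf_ids[j], sigma2)
        mu = np.zeros(leaf_ids[j].max() + 1)
        for q in range(mu.size):                         # Gibbs draw for each leaf mean
            in_leaf = leaf_ids[j] == q
            nq, sq = in_leaf.sum(), r[in_leaf].sum()
            post_var = 1.0 / (nq / sigma2 + 1.0 / tau)   # normal-normal conjugate update
            mu[q] = rng.normal(post_var * sq / sigma2, np.sqrt(post_var))
        fits[j] = mu[leaf_ids[j]]                        # tree j's updated contribution
    resid = y - fits.sum(axis=0)
    # sigma^2 ~ IG((nu + n)/2, (nu*lambda + sum of squared residuals)/2)
    sigma2 = 1.0 / rng.gamma((nu + n) / 2.0, 2.0 / (nu * lam + resid @ resid))
```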
Posterior for \(y\mid\mathbf{x}\) is proportional to \(\text{likelihood}\times \text{prior}\).
Prior to seeing data, expect small contributions from each tree.
Likelihood of tree describes how well proposed tree “fits” the data.
So… trees can grow deeper if data permit.
The BART MCMC algorithm proposes one move per tree at a time and can take a while to mix.
Worryingly, can sometimes get stuck in poor fitting regions of the tree posterior space.
Ideally, the draws are random samples from the posterior and should not show any autocorrelation draw to draw, but BART often does not mix well enough.
(Ronen et al. 2022), albeit for a very simplified version of BART, show that mixing degrades as \(n\) grows.
Starting the sampler at high-probability trees and/or averaging over many independent chains of forests may help.
With many redundant features, it is more likely for BART to be doomed by a poor first split in a tree.
Therefore, we’d want to help by providing variables with strong signal to split on (more to come).
(He and Hahn 2023) combines recursive partitioning (repeatedly calling the same splitting process) from CART (Breiman et al. 1984) with stochastic growing rules instead of greedy optimization.
From Bayes' rule, we can write (where \(s\) represents where to split and \(v\) which variable to split on):
\[\Pr(s,v) = \frac{f(s,v)\pi(s)}{\sum_{s_j}\sum_{v_j}f(s_j,v_j)\pi(s_j)}\]
\(f(s,v)\): marginal likelihood of tree splits, available in analytical form. No split corresponds to tree prior “done splitting” probability.
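A minimal sketch of that sampling step (the function and its inputs are hypothetical placeholders; in practice the candidate set contains every available cutpoint of every variable plus the "no split" option):

```python
import numpy as np

def sample_split(log_marginal_liks, log_prior_weights, rng=np.random.default_rng()):
    """Draw one (split, variable) candidate, or 'no split', with probability
    proportional to f(s, v) * pi(s); inputs are on the log scale.
    By convention, the last entry can represent the 'no split' option."""
    logw = np.asarray(log_marginal_liks) + np.asarray(log_prior_weights)
    probs = np.exp(logw - logw.max())            # numerically stabilized softmax
    probs /= probs.sum()
    return rng.choice(len(probs), p=probs)       # index of the chosen candidate
```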
Build each tree fully, for every tree in the forest. Still pass residuals to the next tree.
Restart entire forest (remake every tree) for each iteration.
Pass the variance parameter at the end of every iteration.
Other computational tricks incorporated to make tree building faster.
Each XBART iteration is slower than BART's one-move-at-a-time MCMC iteration. BUT… mixing is considerably faster. BART usually needs thousands of MCMC draws; XBART needs merely 5-20 iterations!
Amounts to a 20-30x speedup and the capability to handle larger datasets.
However, XBART does not target the exact BART posterior (trees aren't built reversibly).
The variance of the prior on the values drawn in the leaves now has its own prior (a hyperprior)!
\[ \mu_{qj}\sim N(0,\tau) \; ; \; \tau\sim \text{IG}(a,b) \]
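Because the inverse gamma is conjugate for the variance of zero-mean normal draws, the update for \(\tau\) given the current leaf values has a closed form; a sketch of the standard result, with \(L\) denoting the total number of leaves in the forest:

\[
\tau \mid \{\mu_{qj}\} \sim \text{IG}\left(a + \frac{L}{2},\; b + \frac{1}{2}\sum_{j}\sum_{q}\mu_{qj}^2\right)
\]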
Rather than choosing a variable uniformly at random (as in BART), or considering all variables (as in the Bayes' rule approach above), update variable weights based on (Linero 2018): a Dirichlet prior on which variable to split on, with weights updated between sweeps.
Append a “main effect estimate”, \(\hat{y}_{\text{reference}}\), as a column in BART.
Default to OLS estimate of \(\hat{y}=\mathbf{x}\beta\). Works pretty well!
Can include any \(\hat{y}\) (or multiple of them) you want depending on what you know about the DGP. Could be different BART priors, regularized linear regression estimates, or specific basis regressions.
If \(p\) is large, put higher prior weight on splitting on the main effect estimate column.
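A sketch of constructing that augmented covariate matrix with a plain NumPy OLS fit (the data here are simulated for illustration; how the extra column and its split weight are passed to stochtree is not shown):

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 1000, 10
X = rng.normal(size=(n, p))
y = X[:, 0] - 2 * X[:, 1] + 0.5 * rng.normal(size=n)

# Main-effect estimate from OLS (with an intercept), then appended as an extra feature.
X1 = np.column_stack([np.ones(n), X])
beta_hat, *_ = np.linalg.lstsq(X1, y, rcond=None)
y_hat_reference = X1 @ beta_hat

X_augmented = np.column_stack([X, y_hat_reference])   # last column = y-hat reference
```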
(Wang, He, and Hahn 2024) replace BART predictions (after forests are built) with Gaussian process predictions. Only for points deemed to be in extrapolation zone.
Individual trees are not identified!
Estimating \(E(y \mid \mathbf{x})\) is the goal, and the sum of trees (the “forest”) in aggregate accomplishes that goal.
Boosting is smart: A single decision tree can fit data really well if it is very large. However, it is prone to over-fitting and does not sustain that performance on unseen data, limiting its usefulness. Many small trees, each fitting a portion of the data (in BART, sequentially through the “leftovers” of the previous trees), is a very smart idea! Many trees can yield similar likelihoods, and with a sum of trees this is even more pronounced, so trees are certainly not identified… but this is a feature, not a bug.
The sum-of-trees model, with its abundance of unidentified parameters, allows for “fit” to be freely reallocated from one tree to another. Because each move makes only small incremental changes to the fit, we can imagine the algorithm as analogous to sculpting a complex figure by adding and subtracting small dabs of clay.
The original BART paper (Hugh A. Chipman, George, and McCulloch 2012)
Well-calibrated regularization priors are even smarter: BART describes a model-based approach to balance keeping trees (and their leaf contributions) small through regularization priors, while “listening” to the data (the likelihood) if larger trees are necessary. This is the mechanism behind the clay analogy in the quote above.
Sampling different forests is maybe the smartest: Instead of finding an “optimal” single forest (like xgboost), BART tries out a bunch of configurations. BART's samples from the posterior of \(f(\mathbf{x}_i)\) also yield uncertainty quantification. The BART MCMC usually works remarkably well. Averaging many forests when taking the posterior mean has an extra smoothing property for point estimates.
But… it can take a while to converge (despite eventually getting there). XBART provides a way to approximate the BART posterior. Warm-start BART, initializing BART with the XBART-grown trees, starts the MCMC algorithm in a higher-probability region of tree space.
Bonus: modeling the error variance \(\sigma^2\) explicitly seems to help BART fare better in noisier settings. It opens doors to different modeling regimes and better uncertainty quantification.
BART is an excellent performer and has natural uncertainty quantification.
Bespoke BART models are easier to develop than other machine learning tools.
BART usually works pretty well off the bat in a variety of modeling landscapes.