plotTrees {bartMan}    R Documentation
Plot Trees with Customisations
Description
This function plots trees from a list of tidygraph objects. It supports customisations such as filling nodes by their mean response or predicted value, adjusting node sizes, and choosing the colour palette.
Usage
plotTrees(
trees,
iter = NULL,
treeNo = NULL,
fillBy = NULL,
sizeNodes = FALSE,
removeStump = FALSE,
selectedVars = NULL,
pal = rev(colorRampPalette(c("steelblue", "#f7fcfd", "orange"))(5)),
center_Mu = TRUE,
cluster = NULL
)
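The default value of pal shown above is an ordinary base-R expression. The sketch below evaluates it with grDevices alone, without calling bartMan, to show what the default palette contains: five interpolated colours running from orange, through near-white, to steelblue.

```r
# Reproduce the default value of the `pal` argument using grDevices only.
library(grDevices)  # attached by default in R; listed here for clarity

# colorRampPalette() returns a function; calling it with 5 linearly
# interpolates five colours along steelblue -> #f7fcfd -> orange.
ramp <- colorRampPalette(c("steelblue", "#f7fcfd", "orange"))

# rev() flips the order, so the default palette runs orange -> steelblue.
default_pal <- rev(ramp(5))

print(default_pal)
```

Any character vector of colours of this kind can be passed as pal instead.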
Arguments
trees
A data frame of trees, such as the output of extractTreeData() (see Examples).
iter
An integer giving the iteration whose trees are plotted. If NULL, trees from all iterations are included.
treeNo
An integer giving the number of the tree to plot. If NULL, all trees are included.
fillBy
A character string naming the attribute used to colour nodes: 'response' colours nodes by their mean response value, 'mu' colours them by their predicted value, and NULL applies no attribute-based fill.
sizeNodes
A logical value. If TRUE, node sizes are adjusted; if FALSE, all nodes are drawn at the same size.
removeStump
A logical value. If TRUE, stumps (trees consisting only of a root node) are removed from the plot.
selectedVars
The variables to display, given either as a character vector of variable names or as a vector of column numbers.
pal
A colour palette for node colouring. It is used for gradient colouring when fillBy is specified.
center_Mu
A logical value indicating whether to centre the colour scale for 'mu' on zero. Applies only when fillBy = "mu".
cluster
A character string specifying how to reorder trees in the output: "depth" orders trees by the maximum depth of their nodes, and "var" clusters trees based on their variables. If NULL, no reordering is performed.
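How plotTrees() implements center_Mu and cluster = "depth" internally is not described on this page. The base-R sketch below only illustrates the two ideas on made-up data, under stated assumptions: centring a colour scale on zero by using limits symmetric about zero, from -max(|mu|) to max(|mu|), and ordering trees by the maximum depth of their nodes. The data frame nodes, its columns treeNum, depth and mu, and the helpers center_limits() and order_by_depth() are all hypothetical and need not match the structure returned by extractTreeData().

```r
# Hypothetical node table: one row per node, giving the tree it belongs to,
# the node's depth (root = 0) and its predicted value (mu).
nodes <- data.frame(
  treeNum = c(1, 1, 1, 2, 3, 3, 3, 3, 3),
  depth   = c(0, 1, 1, 0, 0, 1, 1, 2, 2),
  mu      = c(0.4, -0.2, 1.5, 0.1, 0.3, -0.9, 0.6, -0.1, 0.2)
)

# Hypothetical helper: limits symmetric about 0, so that mu = 0 maps to
# the midpoint of a diverging palette.
center_limits <- function(mu) {
  m <- max(abs(mu), na.rm = TRUE)
  c(-m, m)
}

# Hypothetical helper: compute each tree's maximum node depth, then sort
# trees from deepest to shallowest (the sort direction is an assumption).
order_by_depth <- function(nodes) {
  max_depth <- tapply(nodes$depth, nodes$treeNum, max)
  as.numeric(names(sort(max_depth, decreasing = TRUE)))
}

lims <- center_limits(nodes$mu)      # -1.5, 1.5
tree_order <- order_by_depth(nodes)  # trees 3, 1, 2
```

In ggplot2, limits like lims could be passed to scale_fill_gradientn(colours = pal, limits = lims) to obtain a zero-centred gradient.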
Value
A ggplot object representing the plotted trees with the specified customisations.
Examples
if (requireNamespace("dbarts", quietly = TRUE)) {
# Load the dbarts package to access the bart function
library(dbarts)
# Get Data
df <- na.omit(airquality)
# Create Simple dbarts Model For Regression:
set.seed(1701)
dbartModel <- bart(df[2:6],
df[, 1],
ntree = 5,
keeptrees = TRUE,
nskip = 10,
ndpost = 10
)
# Tree Data
trees_data <- extractTreeData(model = dbartModel, data = df)
plotTrees(trees = trees_data, fillBy = 'response', sizeNodes = TRUE)
}
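The example above uses only fillBy and sizeNodes. The sketch below reuses the same model setup and combines several other arguments documented under Arguments (iter, fillBy = "mu", center_Mu and cluster = "depth"). It illustrates how those arguments can be combined. The exact appearance of the plot depends on the fitted trees.

```r
# Further usage sketch: the same model as above, plotted with other
# documented arguments. Runs only if dbarts and bartMan are installed.
p <- NULL
if (requireNamespace("dbarts", quietly = TRUE) &&
    requireNamespace("bartMan", quietly = TRUE)) {
  library(dbarts)
  library(bartMan)
  df <- na.omit(airquality)
  set.seed(1701)
  dbartModel <- bart(df[2:6], df[, 1],
                     ntree = 5, keeptrees = TRUE,
                     nskip = 10, ndpost = 10)
  trees_data <- extractTreeData(model = dbartModel, data = df)

  # Trees from iteration 1 only, coloured by predicted value (mu) on a
  # zero-centred scale and ordered by maximum node depth.
  p <- plotTrees(trees = trees_data,
                 iter = 1,
                 fillBy = "mu",
                 center_Mu = TRUE,
                 cluster = "depth")
  print(p)
}
```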