plotTrees {bartMan}    R Documentation

Plot Trees with Customisations

Description

This function plots trees from a list of tidygraph objects. It supports several customisations: node fill colour based on the mean response or the predicted value (mu) at each node, node size adjustment, and colour palettes.

Usage

plotTrees(
  trees,
  iter = NULL,
  treeNo = NULL,
  fillBy = NULL,
  sizeNodes = FALSE,
  removeStump = FALSE,
  selectedVars = NULL,
  pal = rev(colorRampPalette(c("steelblue", "#f7fcfd", "orange"))(5)),
  center_Mu = TRUE,
  cluster = NULL
)

Arguments

trees

Tree data to plot, such as the object returned by extractTreeData() (see Examples).

iter

An integer giving the iteration whose trees are plotted. If NULL, trees from all iterations are included.

treeNo

An integer identifying which tree (by its tree number) to plot. If NULL, all trees are included.
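As a rough sketch of how iter and treeNo narrow the plot, the example below reuses the model from the Examples section and restricts it first to one iteration, then to one tree number. It assumes the dbarts and bartMan packages are installed; the values iter = 1 and treeNo = 2 are arbitrary choices for illustration.

```r
# Sketch only: assumes dbarts and bartMan are installed.
if (requireNamespace("dbarts", quietly = TRUE) &&
    requireNamespace("bartMan", quietly = TRUE)) {
  library(dbarts)
  library(bartMan)
  df <- na.omit(airquality)
  set.seed(1701)
  # Small model, same settings as in Examples
  dbartModel <- bart(df[2:6], df[, 1], ntree = 5, keeptrees = TRUE,
                     nskip = 10, ndpost = 10)
  trees_data <- extractTreeData(model = dbartModel, data = df)
  # Every tree from iteration 1 only (iter = 1 is an arbitrary choice)
  p_iter <- plotTrees(trees = trees_data, iter = 1)
  # Tree number 2 across all iterations (iter left as NULL)
  p_tree <- plotTrees(trees = trees_data, treeNo = 2)
  print(p_iter)
}
```

Leaving both arguments as NULL plots every tree from every iteration, which quickly becomes crowded for larger models.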

fillBy

A character string specifying the attribute used to colour nodes. Options are 'response' (colour nodes by their mean response value), 'mu' (colour nodes by their predicted value), or NULL (no fill attribute).

sizeNodes

A logical value indicating whether to adjust node sizes. If TRUE, node sizes are adjusted; if FALSE, all nodes are given the same size.

removeStump

A logical value. If TRUE, stumps (trees consisting of a single root node) are removed from the plot.

selectedVars

The variables to display, given either as a character vector of variable names or as a numeric vector of column numbers.

pal

A colour palette for node colouring. The palette is used for gradient colouring when 'fillBy' is specified.
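The default for pal is an ordinary character vector of hex colours built with base R's colorRampPalette(), so any vector of that shape can presumably be supplied instead. The sketch below evaluates the default and builds an alternative diverging palette; the colours in my_pal are arbitrary and only for illustration.

```r
# The default 'pal': a 5-colour ramp from steelblue through near-white to orange,
# reversed so it runs orange -> near-white -> steelblue
default_pal <- rev(colorRampPalette(c("steelblue", "#f7fcfd", "orange"))(5))
print(default_pal)

# A custom diverging palette built the same way (colour choices are arbitrary)
my_pal <- colorRampPalette(c("purple", "white", "darkgreen"))(7)
print(my_pal)
# It could then be passed as, e.g.:
# plotTrees(trees = trees_data, fillBy = "mu", pal = my_pal)
```

An odd palette length keeps a single colour at the exact midpoint, which suits a scale centred on zero.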

center_Mu

A logical value indicating whether to centre the colour scale for the 'mu' attribute around zero. Applies only when 'fillBy' is "mu".
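To illustrate what centring a colour scale around zero usually means, the sketch below compares data-driven limits with limits that are symmetric about zero. This is a conceptual illustration only, not bartMan's internal code; the mu values and the helper zero_pos() are hypothetical.

```r
# Hypothetical node predictions (mu), invented for illustration only
mu <- c(-0.8, 0.1, 0.4, 2.5)

# Uncentred: scale limits follow the data range, so zero is off-centre
uncentred <- range(mu)

# Centred: limits symmetric about zero, so mu = 0 lands on the palette midpoint
centred <- c(-1, 1) * max(abs(mu))

# Hypothetical helper: relative position of zero on a 0-1 colour scale
zero_pos <- function(lim) (0 - lim[1]) / (lim[2] - lim[1])
zero_pos(uncentred)  # about 0.24
zero_pos(centred)    # 0.5
```

With a centred scale, positive and negative predictions of equal size receive equally strong colours from opposite ends of the palette.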

cluster

A character string specifying how to reorder trees in the output. Currently supports "depth", which orders trees by the maximum depth of their nodes, and "var", which clusters trees based on variables. If NULL, no reordering is performed.
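A minimal sketch of the two cluster options, again reusing the Examples model and assuming dbarts and bartMan are installed. Restricting to iter = 1 is an arbitrary choice that keeps the number of panels small.

```r
# Sketch only: assumes dbarts and bartMan are installed.
if (requireNamespace("dbarts", quietly = TRUE) &&
    requireNamespace("bartMan", quietly = TRUE)) {
  library(dbarts)
  library(bartMan)
  df <- na.omit(airquality)
  set.seed(1701)
  dbartModel <- bart(df[2:6], df[, 1], ntree = 5, keeptrees = TRUE,
                     nskip = 10, ndpost = 10)
  trees_data <- extractTreeData(model = dbartModel, data = df)
  # Order the trees of iteration 1 by maximum node depth
  p_depth <- plotTrees(trees = trees_data, iter = 1, cluster = "depth")
  # Reorder the same trees using the variable-based clustering instead
  p_var <- plotTrees(trees = trees_data, iter = 1, cluster = "var")
  print(p_depth)
}
```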

Value

A ggplot object representing the plotted trees with the specified customisations.
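Because the return value is a ggplot object, it should accept ordinary ggplot2 additions. The sketch below adds a title with ggplot2::labs(); it assumes dbarts and bartMan (and therefore ggplot2) are installed, and the title text is arbitrary.

```r
# Sketch only: assumes dbarts, bartMan and ggplot2 are installed.
if (requireNamespace("dbarts", quietly = TRUE) &&
    requireNamespace("bartMan", quietly = TRUE)) {
  library(dbarts)
  library(bartMan)
  df <- na.omit(airquality)
  set.seed(1701)
  dbartModel <- bart(df[2:6], df[, 1], ntree = 5, keeptrees = TRUE,
                     nskip = 10, ndpost = 10)
  trees_data <- extractTreeData(model = dbartModel, data = df)
  # Colour nodes by predicted value; center_Mu = TRUE is the default
  p <- plotTrees(trees = trees_data, iter = 1, fillBy = "mu")
  # Add a title with a standard ggplot2 layer
  p <- p + ggplot2::labs(title = "Trees from iteration 1, coloured by mu")
  print(p)
}
```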

Examples

if (requireNamespace("dbarts", quietly = TRUE)) {
 # Load the dbarts package to access the bart function
 library(dbarts)
 # Get Data
 df <- na.omit(airquality)
 # Create Simple dbarts Model For Regression:
 set.seed(1701)
 dbartModel <- bart(df[2:6],
   df[, 1],
   ntree = 5,
   keeptrees = TRUE,
   nskip = 10,
   ndpost = 10
 )
 # Tree Data
 trees_data <- extractTreeData(model = dbartModel, data = df)
 plotTrees(trees = trees_data, fillBy = 'response', sizeNodes = TRUE)
}


[Package bartMan version 0.1.1 Index]