draw_tree {treeheatr} | R Documentation |
Draws the conditional decision tree.
Description
Draws the conditional decision tree output from partykit::ctree(), utilizing ggparty geoms: geom_edge, geom_edge_label, geom_node_label.
Usage
draw_tree(
dat,
fit,
term_dat,
layout,
target_cols = NULL,
title = NULL,
tree_space_top = 0.05,
tree_space_bottom = 0.05,
print_eval = FALSE,
metrics = NULL,
x_eval = 0,
y_eval = 0.9,
task = c("classification", "regression"),
par_node_vars = list(label.size = 0, label.padding = unit(0.15, "lines"), line_list =
list(aes(label = splitvar)), line_gpar = list(list(size = 9)), ids = "inner"),
terminal_vars = list(label.padding = unit(0.25, "lines"), size = 3, col = "white"),
edge_vars = list(color = "grey70", size = 0.5),
edge_text_vars = list(color = "grey30", size = 3, mapping = aes(label =
paste(breaks_label, "*NA")))
)
Arguments
dat |
Dataframe with samples from the original dataset, ordered according to the clustering within each leaf node. |
fit |
party object, e.g., as output from partykit::ctree() |
term_dat |
Dataframe for terminal nodes, must include these columns: id, x, y and y_hat. |
layout |
Dataframe of layout of all nodes, must include these columns: id, x, y and y_hat. |
target_cols |
Character vector of hex values giving the color for each target level; defaults to viridis option B. |
title |
Character string for plot title. |
tree_space_top |
Numeric value to pass to expand for top margin of tree. |
tree_space_bottom |
Numeric value to pass to expand for bottom margin of tree. |
print_eval |
Logical. If TRUE, print evaluation of the tree performance. |
metrics |
A set of metric functions used to evaluate the decision tree; defaults to common metrics for classification/regression problems. Can be defined with 'yardstick::metric_set'. |
x_eval |
Numeric value indicating x position to print performance statistics. |
y_eval |
Numeric value indicating y position to print performance statistics. |
task |
Character string indicating the type of problem, either 'classification' (categorical outcome) or 'regression' (continuous outcome). |
par_node_vars |
Named list containing arguments to be passed to the 'geom_node_label()' call for non-terminal nodes. |
terminal_vars |
Named list containing arguments to be passed to the 'geom_node_label()' call for terminal nodes. |
edge_vars |
Named list containing arguments to be passed to the 'geom_edge()' call for tree edges. |
edge_text_vars |
Named list containing arguments to be passed to the 'geom_edge_label()' call for tree edge annotations. |
Value
A ggplot2 grob object of the decision tree.
Examples
library(treeheatr)
# Fit the conditional tree and compute node layouts, then draw the tree.
x <- compute_tree(penguins, target_lab = 'species')
draw_tree(x$dat, x$fit, x$term_dat, x$layout)
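
The optional arguments can be combined with the same inputs. The following is a minimal sketch, not taken from the package documentation: the title text and the edge color and size values are arbitrary choices for illustration. It assumes that `draw_tree()` returns an object that can be stored and printed like a ggplot, and it supplies both `color` and `size` in `edge_vars` so that it does not rely on how missing list entries are handled.

```r
library(treeheatr)

# Fit the tree and compute layouts, as in the example above.
x <- compute_tree(penguins, target_lab = "species")

# Draw the tree with a title and heavier, darker edges.
# The title string and the edge values are illustrative assumptions.
p <- draw_tree(
  x$dat, x$fit, x$term_dat, x$layout,
  title = "Penguin species",
  edge_vars = list(color = "grey40", size = 1)
)

# Print the plot (assumes the returned object has a print method).
p
```

The other named lists (`par_node_vars`, `terminal_vars`, `edge_text_vars`) are passed in the same way and, per the argument descriptions, are forwarded to the corresponding ggparty geoms.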