Topic 12 Decision Trees
Learning Goals
- Clearly describe the recursive binary splitting algorithm for tree building for both regression and classification
- Compute the weighted average Gini index to measure the quality of a classification tree split
- Compute the sum of squared residuals to measure the quality of a regression tree split
- Explain how recursive binary splitting is a greedy algorithm
- Explain how different tree parameters relate to the bias-variance tradeoff
Slides from today are available here.
Trees in tidymodels
To build tree models in tidymodels, first load the package and set the seed for the random number generator to ensure reproducible results:
library(dplyr)
library(readr)
library(ggplot2)
library(tidymodels)
tidymodels_prefer()
set.seed(___) # Pick your favorite number to fill in the parentheses
To fit a decision tree, we can adapt the following:
ct_spec <- decision_tree() %>%
  set_engine(engine = 'rpart') %>%
  set_args(cost_complexity = NULL, # default is 0.01 (used for pruning a tree)
           min_n = NULL, # min number of observations to try split: default is 20 [I think the documentation has a typo and says 2] (used to stop early)
           tree_depth = NULL) %>% # max depth, number of branches/splits to get to any final group: default is 30 (used to stop early)
  set_mode('classification') # change this for a regression tree

data_rec <- recipe(___ ~ ___, data = ______)

data_wf <- workflow() %>%
  add_model(ct_spec) %>%
  add_recipe(data_rec)

fit_mod <- data_wf %>% # or use tune_grid() to tune any of the parameters above
  fit(data = _____)
Visualizing and interpreting the “best” tree
# Plot the tree (make sure to load the rpart.plot package first)
fit_mod %>%
  extract_fit_engine() %>%
  rpart.plot()

# Get variable importance metrics
# Sum of the goodness of split measures (impurity reduction) for each split for which it was the primary variable.
fit_mod %>%
  extract_fit_engine() %>%
  pluck('variable.importance')
Exercises Part 1
You can download a template RMarkdown file to start from here.
Context
Before proceeding, install the rpart
and rpart.plot
packages (for building and plotting decision trees) by entering install.packages(c("rpart", "rpart.plot"))
in the Console.
Our goal will be to classify types of urban land cover in small subregions within a high resolution aerial image of a land region. Data from the UCI Machine Learning Repository include the observed type of land cover (determined by human eye) and “spectral, size, shape, and texture information” computed from the image. See this page for the data codebook.
Source: https://ncap.org.uk/sites/default/files/EK_land_use_0.jpg
library(dplyr)
library(readr)
library(ggplot2)
library(rpart.plot)
library(tidymodels)
tidymodels_prefer()
# Read in the data
land <- read_csv("https://ajohns24.github.io/portfolio/data/land_cover.csv")

# There are 9 land types, but we'll focus on 3 of them
land <- land %>%
  filter(class %in% c("asphalt", "grass", "tree"))
Exercise 1: Core theme: parametric/nonparametric
What does it mean for a method to be nonparametric? In general, when might we prefer nonparametric to parametric methods? When might we not?
Where do you think trees fall on the parametric/nonparametric spectrum?
Exercise 2: Core theme: Tuning parameters and the BVT
The key feature governing complexity of a tree model is the number of splits used in the tree. How is the number of splits related to model complexity, bias, and variance?
In practice, the number of splits is controlled indirectly through the following tuning parameters (a short sketch after this list shows where they appear in the tidymodels model specification). For each, discuss how low/high parameter settings would affect the number of tree splits.
- min_n: the minimum number of observations that must exist in a node in order for a split to be attempted.
- cost_complexity: the complexity parameter. Any split that does not increase node purity by at least cost_complexity is not attempted.
- tree_depth: the maximum depth of any node of the final tree. The depth of a node is the number of branches that need to be followed to get to a given node from the root node. (The root node has depth 0.)
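For reference, here is a minimal sketch (not part of the exercise) of where these three parameters sit in a parsnip tree specification. The object name and the particular values are placeholders for illustration, not recommended settings:

illustration_spec <- decision_tree(
    cost_complexity = 0.01, # rpart's cp: a split must improve fit by at least this much
    min_n = 20,             # smallest node size that is still eligible for a split
    tree_depth = 10         # cap on how many splits separate a leaf from the root
  ) %>%
  set_engine('rpart') %>%
  set_mode('classification')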
Exercise 3: Building trees in tidymodels
Fit a tree model for the class
outcome (land type), and allow all possible predictors to be considered (~ .
in the model formula).
- Use 10-fold CV.
- Choose a final model whose test overall accuracy is within one standard error of the overall best metric.
- The Gini index impurity measure has a minimum of zero and an upper bound of 1.
- Try a sequence of 30 cost_complexity values from 0.00001 to 0.1 (this will take up to 5 minutes).
# Make sure you understand what each line of code is doing
set.seed(123) # don't change this
data_fold <- vfold_cv(land, v = 10)

ct_spec_tune <- decision_tree() %>%
  set_engine(engine = 'rpart') %>%
  set_args(cost_complexity = tune(),
           min_n = 2,
           tree_depth = NULL) %>%
  set_mode('classification')

data_rec <- recipe(class ~ ., data = land)

data_wf_tune <- workflow() %>%
  add_model(ct_spec_tune) %>%
  add_recipe(data_rec)

param_grid <- grid_regular(cost_complexity(range = c(____, ____)), levels = ____)

tune_res <- tune_grid(
  data_wf_tune,
  resamples = data_fold,
  grid = param_grid,
  metrics = metric_set(accuracy) # change this for regression trees
)
- Make a plot of test performance versus the
cost_complexity
tuning parameter. Does it look as you expected?
autoplot(tune_res) + theme_classic()
- Now choose the cost_complexity value that gives the simplest tree (high or low cost_complexity?) within 1 SE of the max accuracy. Pull out the CV accuracy for the chosen cost_complexity.
best_complexity <- select_by_one_std_err(tune_res, metric = 'accuracy', desc(cost_complexity))
data_wf_final <- finalize_workflow(data_wf_tune, best_complexity)

land_final_fit <- fit(data_wf_final, data = land)

tune_res %>%
  collect_metrics() %>%
  filter(cost_complexity == ________)
Exercise 4: Visualizing trees
Let’s visualize the difference between the trees learned under different cost_complexity values. The code below fits a tree for a lower than optimal cost_complexity value (tree_mod_lowcp) and a higher than optimal cost_complexity value (tree_mod_highcp). We then plot these trees (1st and 3rd) along with our best tree (2nd).
Look at page 3 of the rpart.plot
package vignette (an example-heavy manual) to understand what the plot components mean.
tree_mod_lowcp <- fit(
  data_wf_tune %>%
    update_model(ct_spec_tune %>% set_args(cost_complexity = .00001)),
  data = land
)
tree_mod_highcp <- fit(
  data_wf_tune %>%
    update_model(ct_spec_tune %>% set_args(cost_complexity = .1)),
  data = land
)

# Plot all 3 trees in a row
par(mfrow = c(1,3))
tree_mod_lowcp %>% extract_fit_engine() %>% rpart.plot()
land_final_fit %>% extract_fit_engine() %>% rpart.plot()
tree_mod_highcp %>% extract_fit_engine() %>% rpart.plot()
Verify for a couple of splits the idea of increasing node purity/homogeneity in tree-building. (How is this idea reflected in the numbers in the plot output? A short hand-computation sketch follows this exercise.)
Tuning classification trees (like with the
cost_complexity
parameter) is also referred to as “pruning”. Why does this make sense? NOTE: If “pruning” is a new word to you, first Google it.
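If you would like to check the node purity idea numerically, here is a minimal sketch of a weighted average Gini index computation. The class counts below are made up for illustration (they are not read off the plot); substitute the counts shown in any node of your fitted tree.

# Gini index of a node from its class counts: 1 minus the sum of squared class proportions
gini <- function(counts) {
  p <- counts / sum(counts)
  1 - sum(p^2)
}

# Hypothetical counts of (asphalt, grass, tree) cases sent to each child node
left_counts  <- c(asphalt = 40, grass = 5,  tree = 5)
right_counts <- c(asphalt = 10, grass = 35, tree = 45)

n_left  <- sum(left_counts)
n_right <- sum(right_counts)

# Weighted average Gini index of the split (smaller = purer child nodes)
(n_left * gini(left_counts) + n_right * gini(right_counts)) / (n_left + n_right)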
Exercises Part 2
Exercise 5: Predictions from Trees
- Looking at the plot of the best fitted tree, manually make a soft (probability) and hard (class) prediction for the case shown below.
# Pick out training case 2 to make a prediction
test_case <- land[2,]
# Show only the needed predictors
test_case %>% select(NDVI, Bright_100, SD_NIR, GLCM2_100)

land_final_fit %>% extract_fit_engine() %>% rpart.plot()
- Verify your predictions with the predict() function. (Note: we introduced this code in Logistic Regression, but this type of code applies to any classification model fit in tidymodels.)
# Soft (probability) prediction
predict(land_final_fit, new_data = test_case, type = "prob")
# Hard (class) prediction
predict(land_final_fit, new_data = test_case, type = "class")
Exercise 6: Variable importance in trees
We can obtain numerical variable importance measures from trees. These measure, roughly, “the total decrease in node impurities from splitting on the variable” (even if the variable isn’t ultimately used in the split).
What are the 3 most important predictors by this measure? Does this agree with what you might have expected based on the plot of the best fitted tree? What might greedy behavior have to do with this?
land_final_fit %>%
  extract_fit_engine() %>%
  pluck('variable.importance')
Exercise 7: REGRESSION trees?!
Trees can also be used for regression! Let’s work through a step of building a regression tree by hand.
For the two possible splits below, determine the better split for the tree by computing the sum of squared residuals as the measure of node impurity. (The numbers following Yes: and No: indicate the outcome values of the cases in the left (Yes) and right (No) regions. A small R sketch for checking your arithmetic follows the two splits.)
Split 1: x1 < 3
- Yes: 1, 1, 2, 4
- No: 2, 2, 4, 4

Split 2: x1 < 4
- Yes: 1, 1, 2
- No: 2, 2, 4, 4, 4
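If you want to double-check your by-hand arithmetic, here is a minimal sketch in R. The vectors simply transcribe the outcome values listed above, and ssr is a small helper defined here for illustration.

# Sum of squared residuals around a region's mean outcome
ssr <- function(y) sum((y - mean(y))^2)

# Total SSR for Split 1: x1 < 3
ssr(c(1, 1, 2, 4)) + ssr(c(2, 2, 4, 4))

# Total SSR for Split 2: x1 < 4
ssr(c(1, 1, 2)) + ssr(c(2, 2, 4, 4, 4))

The split with the smaller total SSR is the better of the two.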
In case you want to explore building regression trees in R, try out the following exercises using the College data from the ISLR package. Our goal is to predict graduation rate (Grad.Rate) as a function of other predictors. You can look at the data codebook with ?College in the Console.
library(ISLR)
data(College)
# A little data cleaning
college_clean <- College %>%
  mutate(school = rownames(College)) %>%
  filter(Grad.Rate <= 100) # Remove one school with grad rate of 118%
rownames(college_clean) <- NULL # Remove school names as row names
- Adapt our general decision tree code for the regression setting by adapting the metric used to pick the final model. (Note how other parts stay the same! A hedged sketch of one possible adaptation follows this list.)
- Plot test performance as a function of cost_complexity, and comment on the shape of the plot.
- Plot the “best” tree. (See page 3 of the rpart.plot package vignette for a refresher on what the plot shows.) Do the sequence of splits and outcomes in the leaf nodes make sense?
- Look at the variable importance metrics from the best tree. Do the most important variables align with your intuition?
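Here is a minimal sketch of one possible adaptation, assuming the college_clean data from above. The object names, the rmse metric, the step_rm() of the school identifier, and the cost_complexity grid are all illustrative choices, not the only reasonable ones.

set.seed(123)

college_fold <- vfold_cv(college_clean, v = 10)

rt_spec_tune <- decision_tree() %>%
  set_engine(engine = 'rpart') %>%
  set_args(cost_complexity = tune(),
           min_n = 2,
           tree_depth = NULL) %>%
  set_mode('regression') # regression instead of classification

# school is just an identifier, so drop it from the predictors
college_rec <- recipe(Grad.Rate ~ ., data = college_clean) %>%
  step_rm(school)

college_wf_tune <- workflow() %>%
  add_model(rt_spec_tune) %>%
  add_recipe(college_rec)

college_param_grid <- grid_regular(cost_complexity(range = c(-5, -1)), levels = 30)

college_tune_res <- tune_grid(
  college_wf_tune,
  resamples = college_fold,
  grid = college_param_grid,
  metrics = metric_set(rmse) # an error metric instead of accuracy
)

# Test performance vs. cost_complexity
autoplot(college_tune_res) + theme_classic()

# Simplest tree within 1 SE of the best RMSE, refit to the full data
best_cp <- select_by_one_std_err(college_tune_res, desc(cost_complexity), metric = 'rmse')
college_final_fit <- fit(finalize_workflow(college_wf_tune, best_cp), data = college_clean)

# Plot the tree and pull out variable importance
college_final_fit %>% extract_fit_engine() %>% rpart.plot()
college_final_fit %>% extract_fit_engine() %>% pluck('variable.importance')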