Topic 11 Lasso & Logistic Regression

Learning Goals

  • Describe how you can use LASSO for logistic regression models
  • Calculate (by hand from confusion matrices) and contextually interpret overall accuracy, sensitivity, and specificity (reference formulas are given below)
  • Construct and interpret plots of predicted probabilities across classes
  • Explain how a ROC curve is constructed and the rationale behind AUC as an evaluation metric
  • Appropriately use and interpret the no-information rate to evaluate accuracy metrics
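
As a reference for the by-hand calculations, with \(TP\), \(TN\), \(FP\), and \(FN\) denoting the four cells of a confusion matrix (true/false positives and negatives for the class treated as the positive event):

\[\text{overall accuracy} = \frac{TP + TN}{TP + TN + FP + FN}, \qquad \text{sensitivity} = \frac{TP}{TP + FN}, \qquad \text{specificity} = \frac{TN}{TN + FP}\]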





LASSO for logistic regression in tidymodels

To build LASSO models for logistic regression in tidymodels, first load the necessary packages and set the seed for the random number generator to ensure reproducible results:

library(dplyr)
library(readr)
library(ggplot2)
library(tidymodels)
library(probably) #install.packages('probably')
tidymodels_prefer()

set.seed(2023) # or pick your favorite number to fill in the parentheses




Exercises

You can download a template RMarkdown file to start from here.

Context

Before proceeding, install the pROC package (utilities for evaluating classification models with ROC curves) by entering install.packages("pROC") in the Console.

We’ll continue working with the spam dataset from last time.

  • spam: Either spam or not spam (outcome)
  • word_freq_WORD: percentage of words in the e-mail that match WORD (0-100)
  • char_freq_CHAR: percentage of characters in the e-mail that match CHAR (e.g., exclamation points, dollar signs)
  • capital_run_length_average: average length of uninterrupted sequences of capital letters
  • capital_run_length_longest: length of longest uninterrupted sequence of capital letters
  • capital_run_length_total: sum of length of uninterrupted sequences of capital letters

Our goal will be to use email features to predict whether or not an email is spam - essentially, to build a spam filter!

library(dplyr)
library(readr)
library(ggplot2)
library(tidymodels)
library(probably) #install.packages('probably')
tidymodels_prefer()


spam <- read_csv("https://www.dropbox.com/s/leurr6a30f4l32a/spambase.csv?dl=1")

# A little data cleaning: recode "not spam" (which contains a space) to "not_spam"
spam <- spam %>%
    mutate(spam = ifelse(spam=="spam", "spam", "not_spam"))

Exercise 1: Conceptual warmup

LASSO for logistic regression works analogously to LASSO for linear regression. How would you expect a plot of test roc_auc vs. \(\lambda\) to look, and why?

Exercise 2: Implementing LASSO logistic regression in tidymodels

Fit a LASSO logistic regression model for the spam outcome, and allow all possible predictors to be considered (~ . in the model formula).

  • Use 10-fold CV.
  • Initially try a sequence of 100 \(\lambda\)’s from 1 to 10.
    • Diagnose whether this sequence should be updated by looking at the plot of test AUC vs. \(\lambda\).
    • If needed, adjust the max or min value in the sequence up or down by a factor of 10. (You’ll be able to determine from the plot whether to adjust up or down.)
set.seed(123)

# Make sure you set reference level (to the outcome you are NOT interested in)
spam <- spam %>%
  mutate(spam = relevel(factor(spam), ref='not_spam')) #set reference level

data_cv10 <- vfold_cv(spam, v = 10)


# Logistic LASSO Regression Model Spec
logistic_lasso_spec_tune <- logistic_reg() %>%
    set_engine('glmnet') %>%
    set_args(mixture = 1, penalty = tune()) %>%
    set_mode('classification')

# Recipe
logistic_rec <- recipe(_____, data = spam) %>%
    step_normalize(all_numeric_predictors()) %>% 
    step_dummy(all_nominal_predictors())

# Workflow (Recipe + Model)
log_lasso_wf <- workflow() %>% 
    add_recipe(logistic_rec) %>%
    add_model(logistic_lasso_spec_tune) 

# Tune Model (trying a variety of values of Lambda penalty)
penalty_grid <- grid_regular(
  penalty(range = c(__,__)), #log10 transformed 
  levels = 100)

tune_output <- tune_grid( 
  log_lasso_wf, # workflow
  resamples = data_cv10, # cv folds
  metrics = metric_set(roc_auc,accuracy),
  control = control_resamples(save_pred = TRUE, event_level = 'second'),
  grid = penalty_grid # penalty grid defined above
)

# Visualize Model Evaluation Metrics from Tuning
autoplot(tune_output) + theme_classic()
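
If you want more control over this plot than autoplot() provides, you can build the same CV AUC vs. \(\lambda\) display by hand. The sketch below assumes the tune_output object from the chunk above and is not part of the original template.

# A sketch: plot the mean CV roc_auc against the penalty (lambda) values
tune_output %>%
    collect_metrics() %>%                 # one row per penalty value and metric
    filter(.metric == 'roc_auc') %>%
    ggplot(aes(x = penalty, y = mean)) +
    geom_line() +
    scale_x_log10() +                     # the penalty grid is spaced on a log10 scale
    labs(x = 'lambda (penalty)', y = 'CV roc_auc') +
    theme_classic()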

Exercise 3: Inspecting the model

  1. Inspect the plot of CV AUC vs. \(\lambda\) once more (after adjusting the penalty grid).

Is anything surprising about the results relative to your expectations from Exercise 1? Brainstorm some possible explanations in consideration of the data context.

# Visualize Model Evaluation Metrics from Tuning
autoplot(tune_output) + theme_classic()

  2. Choose a final model whose CV AUC is within one standard error of the overall best metric. Comment on the variables that are removed from the model (a sketch for listing these follows the code chunk below).
# Select Penalty
best_se_penalty <- select_by_one_std_err(tune_output, metric = 'roc_auc', desc(penalty)) # choose penalty value based on the largest penalty within 1 se of the highest CV roc_auc
best_se_penalty

# Fit Final Model
final_fit_se <- finalize_workflow(log_lasso_wf, best_se_penalty) %>% # incorporates penalty value to workflow 
    fit(data = spam)

final_fit_se %>% tidy()
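
To see which variables are removed at the chosen penalty, one option is to list the terms whose coefficients were shrunk exactly to zero. This sketch assumes the final_fit_se object from the chunk above.

# A sketch: predictors dropped from the final model (coefficient exactly zero)
final_fit_se %>%
    tidy() %>%
    filter(estimate == 0)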

  3. Compute each variable's importance by how long it stayed in the model as \(\lambda\) increased.
glmnet_output <- final_fit_se %>% extract_fit_engine()
    
# Create a boolean matrix (predictors x lambdas) of variable exclusion
bool_predictor_exclude <- glmnet_output$beta==0

# Loop over each variable (row) to see how long it stays in the model
var_imp <- sapply(seq_len(nrow(bool_predictor_exclude)), function(row) {
    # TRUE/FALSE for whether this variable is excluded at each lambda
    this_coeff_path <- bool_predictor_exclude[row, ]
    if (sum(this_coeff_path) == ncol(bool_predictor_exclude)) {
        return(0) # excluded at every lambda: importance of 0
    }
    # Count the lambda values from when the variable first enters the model onward
    ncol(bool_predictor_exclude) - which.min(this_coeff_path) + 1
})

# Create a dataset of this information and sort
var_imp_data <- tibble(
    var_name = rownames(bool_predictor_exclude),
    var_imp = var_imp
)
var_imp_data %>% arrange(desc(var_imp))
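
If you would like to see these coefficient paths visually, the underlying glmnet object can be plotted directly with base R's plot() method (a side note, not part of the original template).

# A sketch: coefficient paths across lambda from the underlying glmnet fit.
# Variables whose coefficients remain nonzero even at large lambda values are
# the most "important" by the measure computed above.
plot(glmnet_output, xvar = "lambda", label = TRUE)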

Exercise 4: Interpreting evaluation metrics

Inspect the overall CV results for the “best” \(\lambda\), and compute the no-information rate (NIR):

# CV results for "best lambda"
tune_output %>%
    collect_metrics() %>%
    filter(penalty == best_se_penalty %>% pull(penalty))

# Count up number of spam and not_spam emails in the training data
spam %>%
    count(spam) # Name of the outcome variable goes inside count()

# Compute the NIR
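
One way to finish this computation: the NIR is the proportion of the majority class, i.e., the accuracy you would get by always predicting the most common outcome. The sketch below builds on the count() output above.

# A sketch: compute the NIR as the proportion of the majority class
spam %>%
    count(spam) %>%
    mutate(prop = n / sum(n)) %>% # proportion of each outcome class
    summarize(nir = max(prop))    # NIR = largest class proportion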

  • Why is an AUC of 1 the best possible value for this metric? How does the AUC for our spam model look relative to this best value?

Exercise 5: Using the final model (choosing a threshold)

Once we’ve used LASSO to do model selection (to balance bias and variance), we need to use the final model to make predictions.

# Soft Predictions on Training Data
final_output <- final_fit_se %>% predict(new_data = spam, type='prob') %>% bind_cols(spam)
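
One of the learning goals is to construct and interpret plots of predicted probabilities across classes. A minimal sketch of such a plot, using the final_output object above (not part of the original template):

# A sketch: distributions of the predicted probability of spam, split by true class
final_output %>%
    ggplot(aes(x = .pred_spam, fill = spam)) +
    geom_density(alpha = 0.5) +
    labs(x = 'Predicted probability of spam', fill = 'True class') +
    theme_classic()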

  1. Make an ROC curve of sensitivity and specificity for the final model. First, consider sensitivity and specificity: what do these numbers mean? Do you think higher sensitivity or higher specificity would be more important in designing a spam filter? Comment on the shape of the ROC curve and what it tells you about the model.
# Use soft predictions
final_output %>%
    roc_curve(spam,.pred_spam,event_level = 'second') %>%
    autoplot()
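
If you also want the area under this curve as a single number, a sketch is below (not part of the original template). Note that this AUC is computed on the training data, so it will typically be more optimistic than the CV estimates from Exercise 4.

# A sketch: training-set AUC for the final model's soft predictions
final_output %>%
    roc_auc(spam, .pred_spam, event_level = 'second')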

  2. As with the ROC curve, we can calculate the sensitivity and specificity for a variety of threshold values. Then we will calculate two measures:

  • J index: sensitivity + specificity - 1
  • distance (squared Euclidean distance to the ideal point [false positive rate = 0, sensitivity = 1]): (1 - sensitivity) ^ 2 + (1 - specificity) ^ 2

We can use these measures to help us choose an optimal threshold.

# thresholds in terms of reference level
threshold_output <- final_output %>%
    threshold_perf(truth = spam, estimate = .pred_not_spam, thresholds = seq(0,1,by=.01)) 

# J-index v. threshold for not_spam
threshold_output %>%
    filter(.metric == 'j_index') %>%
    ggplot(aes(x = .threshold, y = .estimate)) +
    geom_line() +
    labs(y = 'J-index', x = 'threshold') +
    theme_classic()

threshold_output %>%
    filter(.metric == 'j_index') %>%
    arrange(desc(.estimate))


# Distance v. threshold for not_spam

threshold_output %>%
    filter(.metric == 'distance') %>%
    ggplot(aes(x = .threshold, y = .estimate)) +
    geom_line() +
    labs(y = 'Distance', x = 'threshold') +
    theme_classic()

threshold_output %>%
    filter(.metric == 'distance') %>%
    arrange(.estimate)



log_metrics <- metric_set(accuracy,sens,yardstick::spec)

final_output %>%
    mutate(.pred_class = make_two_class_pred(.pred_not_spam, levels(spam), threshold = ___ )) %>%
    log_metrics(truth = spam, estimate = .pred_class, event_level = 'second')

Choose a threshold and interpret overall accuracy - does this seem high? How can the no-information rate (NIR) help us interpret the overall accuracy?


  3. There are data contexts in which you might want to weigh the relative cost of a false negative against that of a false positive and incorporate that into your choice of threshold.

In this case you can specify:

  • \(cost\), the relative cost of a false negative classification (as compared with a false positive classification)
  • \(prev\), the prevalence, or the proportion of positive cases in the population

and calculate

\[r = \frac{1 - prev}{cost \cdot prev}\]

This value \(r\) gives a measure of the relative cost of false positives to false negatives, weighted by the frequency of the two groups. If \(cost = 1\) and \(prev = 0.5\), then \(r = 1\). If, say, a false negative is considered 10 times as costly as a false positive (\(cost = 10\)) while \(prev = 0.5\), then \(r = 0.1\), so specificity is down-weighted relative to sensitivity in the measures below.

Then we can update our two measures,

  • Weighted J index: sensitivity + r*specificity - 1
  • Weighted distance: (1 - sensitivity) ^ 2 + r*(1 - specificity) ^ 2

Computational Thinking Exercise: If you were to implement these weighted measures (the weighted J index and weighted distance), how would you go about using the following ROC output to calculate them and find the optimal threshold? Write out the steps in words, naming the functions you might use, but don't actually code it up.

final_output %>%
    roc_curve(spam,.pred_spam,event_level = 'second')

Exercise 6: Algorithmic understanding for evaluation metrics

Inspect the iteration-specific information from tuning the LASSO model (focus on the “best” penalty term).

tune_output %>% collect_metrics(summarize = FALSE) %>% filter(penalty == best_se_penalty %>% pull(penalty))

How is one row of information computed? Carefully describe (in words) the CV process for a single iteration to estimate each of CV roc_auc and accuracy (overall accuracy). Use a generic confusion matrix (filled with TP, TN, FP, FN instead of hard numbers) to illustrate the underlying computations.