Logistic regression is analogous to linear regression, but for binary outcomes. In this example I estimate whether a customer churned (yes/no) from a set of predictors. The dataset is the Telco Customer Churn data from Kaggle.
glimpse(data)
## Rows: 7,043
## Columns: 21
## $ customerID <fct> 7590-VHVEG, 5575-GNVDE, 3668-QPYBK, 7795-CFOCW, 9237-…
## $ gender <fct> Female, Male, Male, Male, Female, Female, Male, Femal…
## $ SeniorCitizen <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ Partner <fct> Yes, No, No, No, No, No, No, No, Yes, No, Yes, No, Ye…
## $ Dependents <fct> No, No, No, No, No, No, Yes, No, No, Yes, Yes, No, No…
## $ tenure <int> 1, 34, 2, 45, 2, 8, 22, 10, 28, 62, 13, 16, 58, 49, 2…
## $ PhoneService <fct> No, Yes, Yes, No, Yes, Yes, Yes, No, Yes, Yes, Yes, Y…
## $ MultipleLines <fct> No phone service, No, No, No phone service, No, Yes, …
## $ InternetService <fct> DSL, DSL, DSL, DSL, Fiber optic, Fiber optic, Fiber o…
## $ OnlineSecurity <fct> No, Yes, Yes, Yes, No, No, No, Yes, No, Yes, Yes, No …
## $ OnlineBackup <fct> Yes, No, Yes, No, No, No, Yes, No, No, Yes, No, No in…
## $ DeviceProtection <fct> No, Yes, No, Yes, No, Yes, No, No, Yes, No, No, No in…
## $ TechSupport <fct> No, No, No, Yes, No, No, No, No, Yes, No, No, No inte…
## $ StreamingTV <fct> No, No, No, No, No, Yes, Yes, No, Yes, No, No, No int…
## $ StreamingMovies <fct> No, No, No, No, No, Yes, No, No, Yes, No, No, No inte…
## $ Contract <fct> Month-to-month, One year, Month-to-month, One year, M…
## $ PaperlessBilling <fct> Yes, No, Yes, No, Yes, Yes, Yes, No, Yes, No, Yes, No…
## $ PaymentMethod <fct> Electronic check, Mailed check, Mailed check, Bank tr…
## $ MonthlyCharges <dbl> 29.85, 56.95, 53.85, 42.30, 70.70, 99.65, 89.10, 29.7…
## $ TotalCharges <dbl> 29.85, 1889.50, 108.15, 1840.75, 151.65, 820.50, 1949…
## $ Churn <fct> No, No, Yes, No, Yes, Yes, No, No, Yes, No, No, No, N…
This data contains 11 NA values.
any(is.na(data))
## [1] TRUE
is.na(data) %>% sum
## [1] 11
visdat::vis_miss(data)
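To locate the missing values by column rather than in total, the counts can be broken down with colSums(). This is a minimal sketch on a made-up data frame; the object toy and its values are assumptions for illustration, not the Telco data.

```r
# Toy data frame standing in for the Telco data (values are made up).
toy <- data.frame(
  tenure         = c(0, 12, 0, 34),
  MonthlyCharges = c(20.5, 70.1, 55.0, 99.9),
  TotalCharges   = c(NA, 841.2, NA, 3396.6)
)

# Number of missing values in each column.
na_per_column <- colSums(is.na(toy))
na_per_column

# Names of the columns that contain at least one NA.
names(na_per_column[na_per_column > 0])
```

Applied to the real data, the same call would list every column alongside its NA count.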
Upon further inspection, all these NA values are found in the TotalCharges column. The total charges are undefined when the tenure (number of months the customer has been with the company) is zero.
data %>%
  filter(is.na(TotalCharges)) %>%
  DT::datatable(class = 'cell-border stripe',
                rownames = FALSE,
                options = list(dom = 'tp',
                               scrollX = TRUE,
                               pageLength = 6,
                               autoWidth = TRUE,
                               columnDefs = list(list(className = "dt-center", targets = 0:20))))
These are the only rows with zero tenure in the entire dataset, which suggests that these customers have simply not been billed yet. The missingness is therefore informative rather than random, and the NA values can be replaced with zero.
data <- mutate(data, TotalCharges=replace(TotalCharges, is.na(TotalCharges), 0))
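The claim that the NA rows and the zero-tenure rows coincide can be checked explicitly before any replacement. The sketch below uses a made-up data frame (toy is an assumption, not the Telco data) and base R only.

```r
# Made-up rows: TotalCharges is missing exactly where tenure is zero.
toy <- data.frame(
  tenure       = c(0, 12, 0, 34, 1),
  TotalCharges = c(NA, 841.2, NA, 3396.6, 20.5)
)

# TRUE only if the NA pattern and the zero-tenure pattern match row by row.
same_rows <- identical(is.na(toy$TotalCharges), toy$tenure == 0)
same_rows

# Replace the missing totals with zero, mirroring the mutate() call.
toy$TotalCharges[is.na(toy$TotalCharges)] <- 0
toy
```

If same_rows were FALSE, zero imputation would need a closer look, because some missing totals would belong to customers who have already been charged.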
The customer ID column is a unique identifier that is not used in the prediction process, so it is removed from the data.
data <- data %>%
select(-customerID)
Here is a complete rundown of the data. The phone service factor is heavily imbalanced: about 90% of customers have phone service.
Hmisc::describe(data)
## data
##
## 20 Variables 7043 Observations
## --------------------------------------------------------------------------------
## gender
## n missing distinct
## 7043 0 2
##
## Value Female Male
## Frequency 3488 3555
## Proportion 0.495 0.505
## --------------------------------------------------------------------------------
## SeniorCitizen
## n missing distinct Info Sum Mean Gmd
## 7043 0 2 0.408 1142 0.1621 0.2717
##
## --------------------------------------------------------------------------------
## Partner
## n missing distinct
## 7043 0 2
##
## Value No Yes
## Frequency 3641 3402
## Proportion 0.517 0.483
## --------------------------------------------------------------------------------
## Dependents
## n missing distinct
## 7043 0 2
##
## Value No Yes
## Frequency 4933 2110
## Proportion 0.7 0.3
## --------------------------------------------------------------------------------
## tenure
## n missing distinct Info Mean Gmd .05 .10
## 7043 0 73 0.999 32.37 28.08 1 2
## .25 .50 .75 .90 .95
## 9 29 55 69 72
##
## lowest : 0 1 2 3 4, highest: 68 69 70 71 72
## --------------------------------------------------------------------------------
## PhoneService
## n missing distinct
## 7043 0 2
##
## Value No Yes
## Frequency 682 6361
## Proportion 0.097 0.903
## --------------------------------------------------------------------------------
## MultipleLines
## n missing distinct
## 7043 0 3
##
## Value No No phone service Yes
## Frequency 3390 682 2971
## Proportion 0.481 0.097 0.422
## --------------------------------------------------------------------------------
## InternetService
## n missing distinct
## 7043 0 3
##
## Value DSL Fiber optic No
## Frequency 2421 3096 1526
## Proportion 0.344 0.440 0.217
## --------------------------------------------------------------------------------
## OnlineSecurity
## n missing distinct
## 7043 0 3
##
## Value No No internet service Yes
## Frequency 3498 1526 2019
## Proportion 0.497 0.217 0.287
## --------------------------------------------------------------------------------
## OnlineBackup
## n missing distinct
## 7043 0 3
##
## Value No No internet service Yes
## Frequency 3088 1526 2429
## Proportion 0.438 0.217 0.345
## --------------------------------------------------------------------------------
## DeviceProtection
## n missing distinct
## 7043 0 3
##
## Value No No internet service Yes
## Frequency 3095 1526 2422
## Proportion 0.439 0.217 0.344
## --------------------------------------------------------------------------------
## TechSupport
## n missing distinct
## 7043 0 3
##
## Value No No internet service Yes
## Frequency 3473 1526 2044
## Proportion 0.493 0.217 0.290
## --------------------------------------------------------------------------------
## StreamingTV
## n missing distinct
## 7043 0 3
##
## Value No No internet service Yes
## Frequency 2810 1526 2707
## Proportion 0.399 0.217 0.384
## --------------------------------------------------------------------------------
## StreamingMovies
## n missing distinct
## 7043 0 3
##
## Value No No internet service Yes
## Frequency 2785 1526 2732
## Proportion 0.395 0.217 0.388
## --------------------------------------------------------------------------------
## Contract
## n missing distinct
## 7043 0 3
##
## Value Month-to-month One year Two year
## Frequency 3875 1473 1695
## Proportion 0.550 0.209 0.241
## --------------------------------------------------------------------------------
## PaperlessBilling
## n missing distinct
## 7043 0 2
##
## Value No Yes
## Frequency 2872 4171
## Proportion 0.408 0.592
## --------------------------------------------------------------------------------
## PaymentMethod
## n missing distinct
## 7043 0 4
##
## Value Bank transfer (automatic) Credit card (automatic)
## Frequency 1544 1522
## Proportion 0.219 0.216
##
## Value Electronic check Mailed check
## Frequency 2365 1612
## Proportion 0.336 0.229
## --------------------------------------------------------------------------------
## MonthlyCharges
## n missing distinct Info Mean Gmd .05 .10
## 7043 0 1585 1 64.76 34.39 19.65 20.05
## .25 .50 .75 .90 .95
## 35.50 70.35 89.85 102.60 107.40
##
## lowest : 18.25 18.40 18.55 18.70 18.75, highest: 118.20 118.35 118.60 118.65 118.75
## --------------------------------------------------------------------------------
## TotalCharges
## n missing distinct Info Mean Gmd .05 .10
## 7043 0 6531 1 2280 2449 48.60 83.47
## .25 .50 .75 .90 .95
## 398.55 1394.55 3786.60 5973.69 6921.02
##
## lowest : 0.00 18.80 18.85 18.90 19.00
## highest: 8564.75 8594.40 8670.10 8672.45 8684.80
## --------------------------------------------------------------------------------
## Churn
## n missing distinct
## 7043 0 2
##
## Value No Yes
## Frequency 5174 1869
## Proportion 0.735 0.265
## --------------------------------------------------------------------------------
Next, the numeric variables are reviewed using a correlation matrix. It shows a strong positive correlation between the monthly and total charges: the higher the monthly charge, the higher the total. More interestingly, senior citizen status and tenure show no correlation.
data %>%
  select(where(is.numeric)) %>%
  cor() %>%
  corrplot::corrplot(type = "upper",
                     tl.col = "black",
                     method = "number")
The senior citizen variable is coded as 0 or 1. It is recast as a factor before use in the model, so that it is treated as a category like the other yes/no variables.
data$SeniorCitizen %>%
table %>%
barplot(col = "steelblue", main = "Senior Citizen")
The data is split into training and testing sets (75%/25%). Stratified sampling on the imbalanced phone service variable keeps its proportions similar in both sets.
data_split <- initial_split(data, prop = 3/4, strata = PhoneService)
data_train <- training(data_split)
data_test <- testing(data_split)
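The idea behind a stratified split can be sketched in base R: sample the same fraction of rows within each level of the stratifying variable. This is an illustration on made-up data, not rsample's actual implementation; toy, train_idx and the seed value are assumptions. Calling set.seed() before initial_split() would make the real split reproducible in the same way.

```r
set.seed(123)  # arbitrary seed so the split is reproducible

# Made-up data with a 90/10 imbalance, similar in spirit to PhoneService.
toy <- data.frame(
  id           = 1:1000,
  PhoneService = rep(c("Yes", "No"), times = c(900, 100))
)

# Sample 3/4 of the row indices within each stratum.
train_idx <- unlist(lapply(
  split(seq_len(nrow(toy)), toy$PhoneService),
  function(rows) sample(rows, size = round(0.75 * length(rows)))
))

toy_train <- toy[train_idx, ]
toy_test  <- toy[-train_idx, ]

# The class proportions are preserved in both sets.
prop.table(table(toy_train$PhoneService))
prop.table(table(toy_test$PhoneService))
```

The same approach applies to any stratifying column, including the outcome itself.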
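The recipe converts the senior citizen variable to a yes/no factor before dummy encoding all nominal predictors. Because step_num2factor() uses the integer values as positions in levels, the 0/1 coding is first shifted to 1/2 by the transform function. Very little preprocessing is required for this dataset.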
lr_recipe <- recipe(Churn ~ ., data_train) %>%
  step_num2factor(SeniorCitizen,
                  transform = function(x) x + 1,
                  levels = c("No", "Yes"),
                  ordered = FALSE) %>%
  step_dummy(all_nominal(), -all_outcomes())
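What these two steps do can be illustrated in base R. The sketch below is not how recipes implements them; it only mirrors the logic on a made-up vector (senior, levels_wanted and dummies are hypothetical names).

```r
senior        <- c(0, 1, 0, 0, 1)   # 0/1 coding, as in SeniorCitizen
levels_wanted <- c("No", "Yes")

# Shift 0/1 to 1/2 so the values can index into the level labels.
senior_fct <- factor(levels_wanted[senior + 1], levels = levels_wanted)
senior_fct

# Treatment (dummy) coding: the first level is the reference,
# so only a "Yes" indicator column remains after dropping the intercept.
dummies <- model.matrix(~ senior_fct)[, -1, drop = FALSE]
colnames(dummies)
dummies
```

Dropping the reference level is why the fitted model reports SeniorCitizen_Yes but no SeniorCitizen_No coefficient.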
A logistic regression model is specified with the glm engine, which fits the model using stats::glm() with a binomial family.
lr_model <-
logistic_reg(mode = "classification") %>%
set_engine("glm")
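For intuition, the glm engine amounts to a binomial GLM with a logit link. The sketch below fits one directly on simulated data (the variables tenure and churned and all parameter values are made up) and shows that the predicted probability is the logistic transform of the linear predictor.

```r
set.seed(42)  # arbitrary seed for the simulated data

# Simulated predictor and binary outcome (not the Telco data).
tenure  <- sample(0:72, 300, replace = TRUE)
churned <- rbinom(300, 1, plogis(0.5 - 0.05 * tenure))

# Binomial GLM with the default logit link.
fit <- glm(churned ~ tenure, family = binomial)

# The linear predictor is on the log-odds scale;
# plogis() maps it to a probability between 0 and 1.
log_odds <- predict(fit, type = "link")
prob     <- predict(fit, type = "response")

# Largest difference between the two routes to the probability.
max(abs(plogis(log_odds) - prob))
```

This is the sense in which logistic regression resembles linear regression: the model is linear in the predictors, but on the log-odds scale rather than on the outcome itself.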
The workflow object ties the entire process together.
lr_wf <-
workflow() %>%
add_model(lr_model) %>%
add_recipe(lr_recipe)
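The workflow is fit to the training data and the estimated coefficients are printed. The NA coefficients belong to dummy columns that duplicate information already in the model: every "No internet service" level coincides with InternetService_No, and the MultipleLines "No phone service" level is the complement of PhoneService_Yes, so glm cannot estimate them separately.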
(lr_wf_fit <-
lr_wf %>%
fit(data_train))
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 2 Recipe Steps
##
## • step_num2factor()
## • step_dummy()
##
## ── Model ───────────────────────────────────────────────────────────────────────
##
## Call: stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)
##
## Coefficients:
## (Intercept) tenure
## 0.9600053 -0.0573502
## MonthlyCharges TotalCharges
## -0.0318682 0.0002789
## gender_Male SeniorCitizen_Yes
## -0.0396671 0.2946431
## Partner_Yes Dependents_Yes
## -0.0751205 -0.1338094
## PhoneService_Yes MultipleLines_No.phone.service
## -0.0376969 NA
## MultipleLines_Yes InternetService_Fiber.optic
## 0.4225385 1.5559137
## InternetService_No OnlineSecurity_No.internet.service
## -1.5484668 NA
## OnlineSecurity_Yes OnlineBackup_No.internet.service
## -0.1979712 NA
## OnlineBackup_Yes DeviceProtection_No.internet.service
## 0.0012428 NA
## DeviceProtection_Yes TechSupport_No.internet.service
## 0.0441192 NA
## TechSupport_Yes StreamingTV_No.internet.service
## -0.1931264 NA
## StreamingTV_Yes StreamingMovies_No.internet.service
## 0.6293459 NA
## StreamingMovies_Yes Contract_One.year
## 0.5399053 -0.6815223
## Contract_Two.year PaperlessBilling_Yes
## -1.4071488 0.3228750
## PaymentMethod_Credit.card..automatic. PaymentMethod_Electronic.check
## -0.0001846 0.3308429
## PaymentMethod_Mailed.check
## 0.0498946
##
## Degrees of Freedom: 5281 Total (i.e. Null); 5258 Residual
## Null Deviance: 6228
## Residual Deviance: 4441 AIC: 4489
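The coefficients are on the log-odds scale, which is hard to read directly. Exponentiating them gives odds ratios. The sketch below copies three coefficients from the printed output by hand; the vector coefs is only an illustration, not something extracted from the fitted object.

```r
# A few coefficients copied from the printed model output (log-odds scale).
coefs <- c(
  tenure                      = -0.0573502,
  Contract_Two.year           = -1.4071488,
  InternetService_Fiber.optic =  1.5559137
)

# Exponentiating gives odds ratios: values below 1 lower the odds of churn,
# values above 1 raise them, holding the other predictors fixed.
odds_ratios <- exp(coefs)
round(odds_ratios, 3)
```

For example, the odds ratio for a two-year contract is about 0.24, meaning the odds of churning are roughly a quarter of those for the reference level (month-to-month), all else being equal.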
Using the fitted model, churn is predicted for the test data. The predicted and actual classes are then bound together for further analysis.
pred <-
  predict(lr_wf_fit, data_test) %>%
  bind_cols(truth = data_test$Churn)
pred %>%
  DT::datatable(class = 'cell-border stripe',
                rownames = FALSE,
                colnames = c("Prediction", "Truth"),
                options = list(dom = 'tp',
                               pageLength = 10,
                               autoWidth = TRUE,
                               columnDefs = list(list(className = "dt-center", targets = 0:1),
                                                 list(width = '200px', targets = "_all"))))
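The class predictions are derived from predicted probabilities, which predict() also returns when called with type = "prob". How a cut-off turns probabilities into classes can be sketched in base R; the probabilities and the alternative cut-off of 0.3 below are made up for illustration.

```r
# Made-up churn probabilities for six customers.
prob_yes <- c(0.05, 0.35, 0.48, 0.52, 0.70, 0.93)

# Conventional rule: predict "Yes" when the probability is at least 0.5.
pred_default <- factor(ifelse(prob_yes >= 0.5, "Yes", "No"),
                       levels = c("No", "Yes"))

# A lower cut-off flags more customers as churners,
# trading false negatives for false positives.
pred_lenient <- factor(ifelse(prob_yes >= 0.3, "Yes", "No"),
                       levels = c("No", "Yes"))

table(pred_default)
table(pred_lenient)
```

Whether a lower cut-off is worthwhile depends on the relative cost of missing a churner versus contacting a loyal customer unnecessarily.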
A confusion matrix is an intuitive way to display the results. It shows the number of true positives, false positives, true negatives and false negatives.
pred %>%
  conf_mat(truth, .pred_class) %>%
  pluck(1) %>%
  as_tibble() %>%
  ggplot(aes(Prediction, Truth, alpha = n)) +
  geom_tile(show.legend = FALSE) +
  geom_text(aes(label = n), colour = "white", alpha = 1, size = 8) +
  theme_minimal() +
  theme(panel.grid.major = element_blank()) +
  scale_x_discrete(expand = c(0, 0)) +
  scale_y_discrete(expand = c(0, 0))
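The four cells of a confusion matrix can also be read off a plain table(). The sketch below uses ten made-up predictions and assumes "Yes" (churn) is the positive class; truth, pred and cm here are toy objects, not the results above.

```r
lvls  <- c("No", "Yes")
truth <- factor(c("No", "No", "No", "No", "No", "No", "Yes", "Yes", "Yes", "Yes"),
                levels = lvls)
pred  <- factor(c("No", "No", "No", "No", "No", "Yes", "Yes", "No", "No", "No"),
                levels = lvls)

# Rows are predictions, columns are the true classes.
cm <- table(Prediction = pred, Truth = truth)
cm

tp <- cm["Yes", "Yes"]  # churners correctly flagged
fn <- cm["No",  "Yes"]  # churners the model missed
tn <- cm["No",  "No"]   # non-churners correctly left alone
fp <- cm["Yes", "No"]   # non-churners wrongly flagged

sensitivity <- tp / (tp + fn)  # share of churners that were caught
specificity <- tn / (tn + fp)  # share of non-churners correctly identified
c(sensitivity = sensitivity, specificity = specificity)
```

A low sensitivity alongside a high specificity is the typical signature of a model that produces many false negatives.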
The metrics() function once again returns the model performance. As no custom metrics were requested, the defaults for class predictions are returned: accuracy and Cohen’s kappa. More detail on these metrics can be found on the model performance page.
pred %>%
metrics(truth, .pred_class)
## # A tibble: 2 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.810
## 2 kap binary 0.451
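Both metrics can be computed by hand from a confusion matrix, which makes clear what kappa adds over accuracy: it corrects for the agreement expected by chance. The counts below are made up and do not correspond to the results above; cm and cohen_kappa are names chosen for illustration.

```r
# Made-up confusion matrix (rows: prediction, columns: truth).
cm <- matrix(c(50, 10, 15, 25), nrow = 2,
             dimnames = list(Prediction = c("No", "Yes"),
                             Truth      = c("No", "Yes")))

n <- sum(cm)

# Observed agreement: the share of cases on the diagonal.
accuracy <- sum(diag(cm)) / n

# Agreement expected by chance, from the row and column totals.
expected <- sum(rowSums(cm) * colSums(cm)) / n^2

# Cohen's kappa: observed agreement beyond chance, rescaled to a maximum of 1.
cohen_kappa <- (accuracy - expected) / (1 - expected)

c(accuracy = accuracy, kappa = cohen_kappa)
```

With imbalanced classes the chance agreement is high, so kappa can be much lower than accuracy, as it is for the model here.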
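Clearly, this model is not perfect. An accuracy of 0.810 is only a modest improvement over the roughly 0.735 that would be obtained by always predicting "No" (the share of non-churners in the full data), and there are far too many false negatives. More complex models will be introduced in a later script.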
sessionInfo()
## R version 4.2.2 (2022-10-31)
## Platform: aarch64-apple-darwin20 (64-bit)
## Running under: macOS Ventura 13.3
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] magrittr_2.0.3 yardstick_1.1.0 workflowsets_1.0.1 workflows_1.1.3
## [5] tune_1.1.1 tidyr_1.3.0 tibble_3.2.1 rsample_1.1.1
## [9] recipes_1.0.5 purrr_1.0.1 parsnip_1.1.0 modeldata_1.1.0
## [13] infer_1.0.4 ggplot2_3.4.1 dplyr_1.1.0 dials_1.2.0
## [17] scales_1.2.1 broom_1.0.4 tidymodels_1.0.0
##
## loaded via a namespace (and not attached):
## [1] lubridate_1.9.2 DiceDesign_1.9 tools_4.2.2
## [4] backports_1.4.1 bslib_0.4.2 utf8_1.2.3
## [7] R6_2.5.1 DT_0.27 rpart_4.1.19
## [10] Hmisc_5.0-1 colorspace_2.1-0 nnet_7.3-18
## [13] withr_2.5.0 tidyselect_1.2.0 gridExtra_2.3
## [16] compiler_4.2.2 cli_3.6.0 htmlTable_2.4.1
## [19] labeling_0.4.2 sass_0.4.5 checkmate_2.1.0
## [22] stringr_1.5.0 digest_0.6.31 foreign_0.8-83
## [25] rmarkdown_2.20 base64enc_0.1-3 pkgconfig_2.0.3
## [28] htmltools_0.5.4 parallelly_1.35.0 lhs_1.1.6
## [31] fastmap_1.1.1 highr_0.10 htmlwidgets_1.6.1
## [34] rlang_1.1.0 rstudioapi_0.14 jquerylib_0.1.4
## [37] generics_0.1.3 farver_2.1.1 jsonlite_1.8.4
## [40] crosstalk_1.2.0 Formula_1.2-5 Matrix_1.5-1
## [43] Rcpp_1.0.10 munsell_0.5.0 fansi_1.0.4
## [46] GPfit_1.0-8 lifecycle_1.0.3 furrr_0.3.1
## [49] visdat_0.6.0 stringi_1.7.12 yaml_2.3.7
## [52] MASS_7.3-58.1 grid_4.2.2 parallel_4.2.2
## [55] listenv_0.9.0 lattice_0.20-45 splines_4.2.2
## [58] knitr_1.42 pillar_1.8.1 future.apply_1.10.0
## [61] codetools_0.2-18 glue_1.6.2 evaluate_0.20
## [64] data.table_1.14.8 vctrs_0.6.1 foreach_1.5.2
## [67] gtable_0.3.1 future_1.32.0 cachem_1.0.7
## [70] xfun_0.37 gower_1.0.1 prodlim_2023.03.31
## [73] class_7.3-20 survival_3.4-0 timeDate_4022.108
## [76] iterators_1.0.14 hardhat_1.3.0 corrplot_0.92
## [79] cluster_2.1.4 lava_1.7.2.1 timechange_0.2.0
## [82] globals_0.16.2 ellipsis_0.3.2 ipred_0.9-14