Linear regression is one of the simplest machine learning techniques and a natural place to start. On this page, I use linear regression to estimate petal length from the famous iris dataset, and in doing so lay out a typical machine learning workflow.
data <- iris
The data consists of five columns: four numeric and one nominal. For this example I shall create a linear model to predict petal length from petal width. No data is missing and no cleaning is necessary for this dataset.
glimpse(data)
## Rows: 150
## Columns: 5
## $ Sepal.Length <dbl> 5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.…
## $ Sepal.Width <dbl> 3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.…
## $ Petal.Length <dbl> 1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5, 1.…
## $ Petal.Width <dbl> 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2, 0.…
## $ Species <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa, s…
The data is first split into training and testing sets (75/25) using the rsample package; setting a random seed beforehand would make the split, and hence the numbers below, reproducible:
data_split <- initial_split(data, prop = 3/4)
data_train <- training(data_split)
data_test <- testing(data_split)
A linear regression model is specified with parsnip, using lm as the engine…
lr_model <-
linear_reg() %>%
set_engine("lm")
…and a very simple recipe is produced.
lr_recipe <- recipe(Petal.Length ~ Petal.Width, data_train)
A workflow object is used to combine the model and recipe. This also removes the need to manually prep() and bake() the recipe:
lr_wf <-
workflow() %>%
add_model(lr_model) %>%
add_recipe(lr_recipe)
The fit() function then estimates the model from the training data. The resulting lr_fit object contains the model coefficients.
(lr_fit <- fit(lr_wf, data_train))
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: linear_reg()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 0 Recipe Steps
##
## ── Model ───────────────────────────────────────────────────────────────────────
##
## Call:
## stats::lm(formula = ..y ~ ., data = data)
##
## Coefficients:
## (Intercept) Petal.Width
## 1.089 2.237
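As a sanity check, the same formula can be fitted with base R's lm() on the full iris dataset (a cross-check sketch, not part of the workflow above; its coefficients differ slightly from those shown because the tidymodels fit used only the 75% training split):

```r
# Fit Petal.Length ~ Petal.Width on all 150 rows with base R
base_fit <- lm(Petal.Length ~ Petal.Width, data = iris)
coef(base_fit)  # intercept and slope, close to the values above
```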
Using this model, we can now predict petal lengths for the test data.
(pred <- predict(lr_fit, data_test) %>%
select(.pred) %>%
bind_cols(., "truth" = data_test$Petal.Length,
"resid" = data_test$Petal.Length - .$.pred))
## # A tibble: 38 × 3
## .pred truth resid
## <dbl> <dbl> <dbl>
## 1 1.54 1.4 -0.137
## 2 1.54 1.5 -0.0367
## 3 1.54 1.2 -0.337
## 4 1.98 1.5 -0.484
## 5 1.76 1.5 -0.260
## 6 1.54 1.7 0.163
## 7 1.98 1.5 -0.484
## 8 1.98 1.5 -0.484
## 9 1.54 1.5 -0.0367
## 10 2.43 1.6 -0.831
## # … with 28 more rows
The performance of a linear model is typically assessed using RMSE, R² and MAE values. These are obtained using the metrics() function from yardstick.
pred %>% metrics(truth = truth, estimate = .pred)
## # A tibble: 3 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 0.505
## 2 rsq standard 0.916
## 3 mae standard 0.397
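For intuition, these three metrics are simple functions of the truth and prediction vectors. A minimal base-R sketch, demonstrated on small illustrative vectors rather than the test split above:

```r
# Hand-rolled equivalents of the metrics reported by yardstick
rmse <- function(truth, pred) sqrt(mean((truth - pred)^2))  # root mean squared error
mae  <- function(truth, pred) mean(abs(truth - pred))       # mean absolute error
rsq  <- function(truth, pred) cor(truth, pred)^2            # squared Pearson correlation

truth <- c(1.4, 1.5, 4.7, 5.1)  # illustrative actual values
pred  <- c(1.5, 1.6, 4.5, 5.4)  # illustrative predictions
rmse(truth, pred)
mae(truth, pred)
rsq(truth, pred)
```

Note that RMSE penalises large residuals more heavily than MAE, which is why the two values differ above (0.505 vs 0.397).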
A more intuitive visualisation can be obtained by plotting the predicted values against the actual values [left]. The red diagonal line has a slope of 1; points lying on it indicate exact matches, and the points here sit close to it, showing good agreement between predicted and actual values. Additionally, linear regression assumes homoscedasticity (constant variance of the errors). The residual plot [right] is consistent with this assumption, though with only 38 test points the evidence is not conclusive.
ggplt1 <-
ggplot(data = pred, aes(x = .pred, y = truth)) +
geom_point() +
geom_abline(slope = 1, col = "red") +
theme_linedraw() +
labs(x = "Predicted Value", y = "Actual Value") +
lims(x = c(0,9), y = c(0,8))
ggplt2 <-
ggplot(data = pred, aes(x = .pred, y = resid)) +
geom_point() +
geom_abline(slope = 0, intercept = 0, col = "red") +
theme_linedraw() +
labs(x = "Predicted Value", y = "Residuals") +
lims(x = c(0,9), y = c(-2,2))
gridExtra::grid.arrange(ggplt1, ggplt2, nrow = 1)
This simple introduction to linear regression highlights the
following tidymodels concepts:
1. workflow objects
2. recipe creation
3. model selection
4. data fitting
5. data prediction
6. model evaluation
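For comparison, the six steps above can be sketched in base R alone, without tidymodels (a minimal sketch; the exact numbers will differ from those above because the random split differs):

```r
set.seed(42)                                 # make the split reproducible
idx   <- sample(nrow(iris), 0.75 * nrow(iris))
train <- iris[idx, ]                         # 75% training set
test  <- iris[-idx, ]                        # 25% testing set

fit  <- lm(Petal.Length ~ Petal.Width, data = train)  # model fitting
pred <- predict(fit, newdata = test)                  # prediction

sqrt(mean((test$Petal.Length - pred)^2))              # RMSE on the test set
```

The tidymodels version buys little for a single lm() fit, but the workflow/recipe structure pays off once preprocessing steps, resampling and model comparison enter the picture.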
sessionInfo()
## R version 4.2.2 (2022-10-31)
## Platform: aarch64-apple-darwin20 (64-bit)
## Running under: macOS Ventura 13.3
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] yardstick_1.1.0 workflowsets_1.0.1 workflows_1.1.3 tune_1.1.1
## [5] tidyr_1.3.0 tibble_3.2.1 rsample_1.1.1 recipes_1.0.5
## [9] purrr_1.0.1 parsnip_1.1.0 modeldata_1.1.0 infer_1.0.4
## [13] ggplot2_3.4.1 dplyr_1.1.0 dials_1.2.0 scales_1.2.1
## [17] broom_1.0.4 tidymodels_1.0.0 magrittr_2.0.3
##
## loaded via a namespace (and not attached):
## [1] sass_0.4.5 jsonlite_1.8.4 splines_4.2.2
## [4] foreach_1.5.2 prodlim_2023.03.31 bslib_0.4.2
## [7] highr_0.10 GPfit_1.0-8 yaml_2.3.7
## [10] globals_0.16.2 ipred_0.9-14 pillar_1.8.1
## [13] backports_1.4.1 lattice_0.20-45 glue_1.6.2
## [16] digest_0.6.31 hardhat_1.3.0 colorspace_2.1-0
## [19] htmltools_0.5.4 Matrix_1.5-1 timeDate_4022.108
## [22] pkgconfig_2.0.3 lhs_1.1.6 DiceDesign_1.9
## [25] listenv_0.9.0 gower_1.0.1 lava_1.7.2.1
## [28] timechange_0.2.0 farver_2.1.1 generics_0.1.3
## [31] ellipsis_0.3.2 cachem_1.0.7 withr_2.5.0
## [34] furrr_0.3.1 nnet_7.3-18 cli_3.6.0
## [37] survival_3.4-0 evaluate_0.20 future_1.32.0
## [40] fansi_1.0.4 parallelly_1.35.0 MASS_7.3-58.1
## [43] class_7.3-20 tools_4.2.2 data.table_1.14.8
## [46] lifecycle_1.0.3 munsell_0.5.0 compiler_4.2.2
## [49] jquerylib_0.1.4 rlang_1.1.0 grid_4.2.2
## [52] iterators_1.0.14 rstudioapi_0.14 labeling_0.4.2
## [55] rmarkdown_2.20 gtable_0.3.1 codetools_0.2-18
## [58] R6_2.5.1 gridExtra_2.3 lubridate_1.9.2
## [61] knitr_1.42 fastmap_1.1.1 future.apply_1.10.0
## [64] utf8_1.2.3 parallel_4.2.2 Rcpp_1.0.10
## [67] vctrs_0.6.1 rpart_4.1.19 tidyselect_1.2.0
## [70] xfun_0.37