Linear regression is the simplest machine learning technique and a natural starting point. On this page, I use linear regression to estimate petal length from the famous iris dataset, and in doing so lay out a typical machine learning workflow.

data <- iris

The data consists of 5 columns: 4 numeric and 1 nominal. For this example I shall create a linear model to predict the petal length from the petal width. No data is missing, so no cleaning is necessary for this dataset.

glimpse(data)
## Rows: 150
## Columns: 5
## $ Sepal.Length <dbl> 5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.…
## $ Sepal.Width  <dbl> 3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.…
## $ Petal.Length <dbl> 1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5, 1.…
## $ Petal.Width  <dbl> 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2, 0.…
## $ Species      <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa, s…
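The claim that no data is missing can be verified directly; a quick base-R check (a minimal sanity check, not part of the original analysis):

```r
# Count missing values in each column of iris: every count should be zero.
colSums(is.na(iris))
```
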

The data is first split into training and testing sets using the rsample package:

data_split <- initial_split(data, prop = 3/4)
data_train <- training(data_split)
data_test <- testing(data_split)
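Under the hood, `initial_split()` samples row indices at random; a simplified base-R sketch of the same idea (the seed here is my addition for reproducibility — the analysis above does not appear to seed the split, so its exact rows will differ):

```r
set.seed(123)                                    # make the random split reproducible
n <- nrow(iris)
train_idx  <- sample(n, size = floor(0.75 * n))  # 3/4 of rows for training
train_base <- iris[train_idx, ]                  # 112 rows
test_base  <- iris[-train_idx, ]                 # the remaining 38 rows
```

With `prop = 3/4` and 150 rows, this yields 112 training and 38 testing observations, matching the 38-row prediction tibble further down. Note that `initial_split()` also supports stratified sampling, which this sketch omits.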

A linear model is specified with parsnip, using the lm engine…

lr_model <- 
  linear_reg() %>% 
  set_engine("lm")

…and a very simple recipe is produced.

lr_recipe <- recipe(Petal.Length ~ Petal.Width, data_train)

A workflow object combines the model and recipe. It also removes the need to prep and bake the recipe manually:

lr_wf <- 
  workflow() %>% 
  add_model(lr_model) %>% 
  add_recipe(lr_recipe)
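For comparison, this is roughly what the workflow saves you from doing by hand — a sketch of the manual prep/bake steps (assuming the recipes package is loaded, as in the session below):

```r
# Manual alternative to the workflow: estimate the recipe on the
# training data, then apply it to obtain the processed data frame
# that would be passed to the model.
lr_prepped  <- prep(lr_recipe, training = data_train)
baked_train <- bake(lr_prepped, new_data = NULL)  # processed training set
```

Since this recipe has no steps, baking simply returns the selected columns; the convenience of the workflow becomes more apparent once preprocessing steps are added.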

The fit() function then fits the model to the training data. The lr_fit object contains the model coefficients.

(lr_fit <- fit(lr_wf, data_train))
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: linear_reg()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 0 Recipe Steps
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## 
## Call:
## stats::lm(formula = ..y ~ ., data = data)
## 
## Coefficients:
## (Intercept)  Petal.Width  
##       1.089        2.237
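The "lm" engine delegates to base R's lm(), so the coefficients can be cross-checked directly. A sketch fitted on the full iris data (the numbers therefore differ slightly from the 112-row training fit above):

```r
# Fit the same model with base R and inspect the coefficients.
fit_lm <- lm(Petal.Length ~ Petal.Width, data = iris)
coef(fit_lm)  # intercept roughly 1.08, slope roughly 2.23
```
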

Using this model we can now predict petal length on the test data.

(pred <- predict(lr_fit, data_test) %>% 
  select(.pred) %>% 
  bind_cols(., "truth" = data_test$Petal.Length, 
            "resid" = data_test$Petal.Length - .$.pred))
## # A tibble: 38 × 3
##    .pred truth   resid
##    <dbl> <dbl>   <dbl>
##  1  1.54   1.4 -0.137 
##  2  1.54   1.5 -0.0367
##  3  1.54   1.2 -0.337 
##  4  1.98   1.5 -0.484 
##  5  1.76   1.5 -0.260 
##  6  1.54   1.7  0.163 
##  7  1.98   1.5 -0.484 
##  8  1.98   1.5 -0.484 
##  9  1.54   1.5 -0.0367
## 10  2.43   1.6 -0.831 
## # … with 28 more rows

The performance of a linear model is typically graded using rmse, rsq and mae values. These are obtained with the metrics() function from yardstick.

pred %>%  metrics(truth = truth, estimate = .pred)
## # A tibble: 3 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard       0.505
## 2 rsq     standard       0.916
## 3 mae     standard       0.397
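These metrics are straightforward to reproduce by hand, which makes a useful sanity check. A base-R sketch computed on the full dataset (so the values will differ somewhat from the 38-row test-set numbers above):

```r
fit  <- lm(Petal.Length ~ Petal.Width, data = iris)
res  <- iris$Petal.Length - predict(fit)          # residuals
rmse <- sqrt(mean(res^2))                         # root mean squared error
mae  <- mean(abs(res))                            # mean absolute error
rsq  <- cor(predict(fit), iris$Petal.Length)^2    # squared correlation
c(rmse = rmse, rsq = rsq, mae = mae)
```
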

A more intuitive assessment comes from plotting the predicted values against the actual values [left]. The diagonal line has a slope of 1, indicating an exact match; the points lie close to it, so predicted and actual values agree well. Linear regression also assumes homoscedasticity (constant variance among the errors). The residual plot [right] displays this behaviour, though too few data points are available to be conclusive.

ggplt1 <- 
  ggplot(data = pred, aes(x = .pred, y = truth)) +
    geom_point() +
    geom_abline(slope = 1, col = "red") +
    theme_linedraw() +
    labs(x = "Predicted Value", y = "Actual Value") +
    lims(x = c(0,9), y = c(0,8))

ggplt2 <- 
  ggplot(data = pred, aes(x = .pred, y = resid)) +
    geom_point() +
    geom_abline(slope = 0, intercept = 0, col = "red") +
    theme_linedraw() +
    labs(x = "Predicted Value", y = "Residuals") +
    lims(x = c(0,9), y = c(-2,2))

gridExtra::grid.arrange(ggplt1, ggplt2, nrow = 1)

This simple introduction to linear regression highlights the following tidymodels concepts:
1. workflow objects
2. recipe creation
3. model selection
4. data fitting
5. data prediction
6. model evaluation

Meta

sessionInfo()
## R version 4.2.2 (2022-10-31)
## Platform: aarch64-apple-darwin20 (64-bit)
## Running under: macOS Ventura 13.3
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] yardstick_1.1.0    workflowsets_1.0.1 workflows_1.1.3    tune_1.1.1        
##  [5] tidyr_1.3.0        tibble_3.2.1       rsample_1.1.1      recipes_1.0.5     
##  [9] purrr_1.0.1        parsnip_1.1.0      modeldata_1.1.0    infer_1.0.4       
## [13] ggplot2_3.4.1      dplyr_1.1.0        dials_1.2.0        scales_1.2.1      
## [17] broom_1.0.4        tidymodels_1.0.0   magrittr_2.0.3    
## 
## loaded via a namespace (and not attached):
##  [1] sass_0.4.5          jsonlite_1.8.4      splines_4.2.2      
##  [4] foreach_1.5.2       prodlim_2023.03.31  bslib_0.4.2        
##  [7] highr_0.10          GPfit_1.0-8         yaml_2.3.7         
## [10] globals_0.16.2      ipred_0.9-14        pillar_1.8.1       
## [13] backports_1.4.1     lattice_0.20-45     glue_1.6.2         
## [16] digest_0.6.31       hardhat_1.3.0       colorspace_2.1-0   
## [19] htmltools_0.5.4     Matrix_1.5-1        timeDate_4022.108  
## [22] pkgconfig_2.0.3     lhs_1.1.6           DiceDesign_1.9     
## [25] listenv_0.9.0       gower_1.0.1         lava_1.7.2.1       
## [28] timechange_0.2.0    farver_2.1.1        generics_0.1.3     
## [31] ellipsis_0.3.2      cachem_1.0.7        withr_2.5.0        
## [34] furrr_0.3.1         nnet_7.3-18         cli_3.6.0          
## [37] survival_3.4-0      evaluate_0.20       future_1.32.0      
## [40] fansi_1.0.4         parallelly_1.35.0   MASS_7.3-58.1      
## [43] class_7.3-20        tools_4.2.2         data.table_1.14.8  
## [46] lifecycle_1.0.3     munsell_0.5.0       compiler_4.2.2     
## [49] jquerylib_0.1.4     rlang_1.1.0         grid_4.2.2         
## [52] iterators_1.0.14    rstudioapi_0.14     labeling_0.4.2     
## [55] rmarkdown_2.20      gtable_0.3.1        codetools_0.2-18   
## [58] R6_2.5.1            gridExtra_2.3       lubridate_1.9.2    
## [61] knitr_1.42          fastmap_1.1.1       future.apply_1.10.0
## [64] utf8_1.2.3          parallel_4.2.2      Rcpp_1.0.10        
## [67] vctrs_0.6.1         rpart_4.1.19        tidyselect_1.2.0   
## [70] xfun_0.37