Supervised Machine Learning

Tree Regressions, Random Forest & Cross-validation

Francisco Rowe

2020-11-11

This session provides an introduction to supervised machine learning. It is part of Introduction to Statistical Learning in R.

Supervised Machine Learning – Tree Regressions, Random Forest & Cross-validation by Francisco Rowe is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.

1 What is Machine Learning?

Machine learning traces back to Samuel (1959) and is conceived as a subset of artificial intelligence. Machine learning algorithms build mathematical models from sample data in order to make predictions or decisions without being explicitly programmed to perform the task.

Samuel, Arthur. 1959. “Some Studies in Machine Learning Using the Game of Checkers.” IBM Journal of Research and Development 44: 206–26.

1.1 Inference vs. Prediction

\[\hat{y} = \hat{f}(x) \]

Prediction: We are not concerned with the exact form of \(\hat{f}\). It is treated as a black box. The interest is in accurate predictions.

Inference: We are interested in the form of \(\hat{f}\), i.e. the way \(y\) is affected as \(x\) changes. The goal is not necessarily to make predictions.

Machine learning revolves around the problem of prediction: producing predictions of \(y\) from \(x\). Machine learning algorithms are not primarily built to understand the underlying relationship between \(y\) and \(x\).
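As a minimal illustration on simulated data (the variables `x` and `y` and their linear relationship are assumptions, not QLFS data), the same fitted model can serve either goal: inference looks at the form of \(\hat{f}\), here its estimated coefficients, while prediction only uses \(\hat{f}\) to map new \(x\) values to \(\hat{y}\).

```r
# Simulated data (assumed for illustration): y depends linearly on x
set.seed(1)
n <- 200
x <- runif(n, 0, 10)
y <- 2 + 0.5 * x + rnorm(n)
fit <- lm(y ~ x)

# Inference: the interest is in the form of f-hat, e.g. how y changes with x
coef(fit)
confint(fit)

# Prediction: f-hat is used as a black box mapping new x values to y-hat
y_hat <- predict(fit, newdata = data.frame(x = c(1, 5, 9)))
y_hat
```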

1.2 Supervised vs Unsupervised Learning

Supervised: For each observation \(i\) of the explanatory variables \(x_{i}\), there is an associated response \(y_{i}\).

Unsupervised: For each observation \(i\) we observe measurements \(x_{i}\) but no associated response \(y\).
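A minimal sketch on simulated data (the variables below are assumptions for illustration): a supervised method such as linear regression uses the response `y` to guide the fit, whereas an unsupervised method such as k-means clustering only sees the measurements `x` and looks for structure in them.

```r
# Simulated data (assumed for illustration): two groups of x values
set.seed(42)
x <- c(rnorm(50, mean = 0), rnorm(50, mean = 5))
y <- 3 * x + rnorm(100)

# Supervised: the response y guides the fit
sup_fit <- lm(y ~ x)
coef(sup_fit)

# Unsupervised: only x is used; k-means searches for 2 groups in the data
unsup_fit <- kmeans(matrix(x, ncol = 1), centers = 2)
table(unsup_fit$cluster)
```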

1.3 Prediction Accuracy vs. Model Interpretability

Fig.1. Interpretability vs. Flexibility. Source: James et al. (2013).

James, Gareth, Daniela Witten, Trevor Hastie, and Robert Tibshirani. 2013. An Introduction to Statistical Learning. Vol. 112. Springer.

1.4 Regression vs. Classification

Regression problems ~ quantitative response variable

Classification problems ~ qualitative response variable

2 Regression Trees

2.1 The Idea

Regression trees partition a data set into smaller subgroups and then fit a simple constant to the observations in each subgroup. The partitioning is achieved by successive binary splits on the different predictors. As in linear regression, the splits are chosen to minimise the residual sum of squares, \(RSS\) (i.e. the sum of squared differences between the observed and predicted \(y\)). The constant predicted for a subgroup is the average response of all training observations that fall in it. To assess the model, we train it on a subset of our data and use it to predict the remaining, unseen part of the sample.
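To make the splitting rule concrete, here is a minimal sketch of a single binary split: for one numeric predictor it tries every midpoint between consecutive distinct values as a cut point \(t\), predicts each side by its mean response and keeps the cut with the lowest \(RSS\). The helper `best_split` and the toy age/log-pay data are hypothetical; the `tree` package performs this search internally.

```r
# Hypothetical helper: the single binary split of a numeric predictor x
# that minimises the RSS of the response y
best_split <- function(x, y) {
  xs <- sort(unique(x))
  # candidate cut points t: midpoints between consecutive distinct values
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2
  rss <- sapply(cuts, function(t) {
    left  <- y[x < t]
    right <- y[x >= t]
    # each side is predicted by its own mean response
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  t_best <- cuts[which.min(rss)]
  list(cut = t_best,
       rss = min(rss),
       mean_left = mean(y[x < t_best]),
       mean_right = mean(y[x >= t_best]))
}

# Toy data (assumed, not QLFS): log pay jumps by 1 once age reaches 30
set.seed(7)
age <- round(runif(300, 16, 65))
log_pay <- ifelse(age < 30, 5, 6) + rnorm(300, sd = 0.3)
res <- best_split(age, log_pay)
res$cut
c(res$mean_left, res$mean_right)
```

A regression tree applies this search to every predictor, keeps the best split overall, and then repeats the search within each of the two resulting subgroups.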

An example: we want to predict individuals’ net pay income based on age (as a measure of experience) and sex (as a measure of gender discrimination in the labour market). We first remove observations with missing or non-positive income values and log-transform NetPay in our QLFS data so that its distribution has a more typical bell shape.

Fig.2. A regression tree for predicting the log net pay of individuals based on age and gender. At a given internal node, the label (\(X_{j} < t_{k}\)) indicates the left-hand branch emanating from that split, and the right-hand branch corresponds to \(X_{j} \geq t_{k}\). The tree has four internal nodes and five terminal nodes, or leaves. The number in each leaf is the mean of the response for the training observations that fall in it. For example, the predicted log net pay for males aged under 23.5 is 5.241.


2.2 Fitting Regression Trees

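  1. Read the data and create a reduced dataset:
# clean workspace
rm(list=ls())

# load packages (dplyr for data manipulation, tree for regression trees)
library(dplyr)
library(tree)

# load data
load("../data/data_qlfs.RData")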

# select variables to create a new df
data <- qlfs %>% 
  dplyr::select(NetPay, Sex, Age) %>%
  filter(!is.na(NetPay)) %>%
  filter(NetPay > 0)

# log-transform NetPay
data$NetPay <- log(data$NetPay)

# remove qlfs from the workspace
rm(qlfs)
  2. Create training (70%) and test (30%) sets and use set.seed() for reproducibility:
# create an index variable to identify a training set
set.seed(123)
data_train <- sample(
  1:nrow(data),
  round(0.7* (nrow(data)))
  )
  3. Fit a regression tree:
# fit regression tree
m1_tree <- tree(
  formula = NetPay~.,
  data    = data,
  subset = data_train
  )
summary(m1_tree)
## 
## Regression tree:
## tree(formula = NetPay ~ ., data = data, subset = data_train)
## Number of terminal nodes:  5 
## Residual mean deviance:  0.418 = 3203 / 7663 
## Distribution of residuals:
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## -5.43800 -0.33630  0.04319  0.00000  0.40800  2.25600
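As we understand the `tree` package, the residual mean deviance in this summary is the deviance (for a regression tree, the training \(RSS\)) divided by the residual degrees of freedom \(n - |T|\), where \(|T|\) is the number of terminal nodes and \(R_m\) is the subgroup of the \(m\)-th terminal node. With \(|T| = 5\), the printed ratio 3203 / 7663 implies \(n = 7668\) training observations:

\[
\frac{\sum_{m=1}^{|T|}\sum_{i:\,x_i \in R_m}\left(y_i - \hat{y}_{R_m}\right)^2}{n - |T|} = \frac{3203}{7668 - 5} = \frac{3203}{7663} \approx 0.418
\]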

2.3 Visualising the tree

The code below replicates the tree shown in Fig.2.

# plot the tree
plot(m1_tree)
# add split labels; pretty = 0 prints factor level names in full
text(m1_tree, pretty = 0)

2.4 Tuning

The process described above may produce good predictions on the training set, but it is likely to overfit the data, leading to poor test set performance. This is because the resulting tree might be too complex.

The key idea is to minimise the expected test error: find a smaller tree with fewer splits, which leads to lower variance and better interpretation at the cost of a little bias, i.e. strike the right balance between variance and bias.

To achieve a good balance between variance and bias, a common approach is cost complexity pruning: grow a very large tree and then prune it back to obtain a subtree; that is, reduce the complexity of the tree by finding the subtree that returns the lowest test error rate. We need to find the tuning parameter \(a\) that penalises the tree complexity.
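Following James et al. (2013), the pruning step can be written as a penalised criterion. For each value of the tuning parameter \(a \geq 0\) there is a subtree \(T\) of the large tree that minimises

\[
\sum_{m=1}^{|T|}\sum_{i:\,x_i \in R_m}\left(y_i - \hat{y}_{R_m}\right)^2 + a\,|T|
\]

where \(|T|\) is the number of terminal nodes, \(R_m\) is the subgroup of the \(m\)-th terminal node and \(\hat{y}_{R_m}\) is the mean training response in \(R_m\). With \(a = 0\) the full tree is kept; as \(a\) grows, every extra terminal node costs more, so smaller subtrees are favoured.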

2.4.1 Cost Complexity Pruning

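  1. Use k-fold cross-validation to choose \(a\); cv.tree() uses k = 10 folds by default:
# cross-validate the cost-complexity sequence of subtrees
cv_data <- cv.tree(m1_tree)
# cross-validated deviance against tree size (number of terminal nodes)
plot(cv_data$size, cv_data$dev, type = 'b')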

  2. We now prune the tree:

    Fig.3. Pruned tree with three terminal nodes

# keep the best subtree with 3 terminal nodes
prune_data <- prune.tree(m1_tree, best = 3)
plot(prune_data)
text(prune_data, pretty = 0)

2.5 Prediction

To make predictions, we use the unpruned tree on the test dataset.

# predict log net pay for the test set (observations not in data_train)
yhat <- predict(m1_tree, newdata = data[-data_train, ])

# observed log net pay for the test set
data_test <- data$NetPay[-data_train]

# compute the mean squared error (MSE)
mse <- mean((yhat - data_test)^2)
mse
## [1] 0.439887
# take the square root
sqrt(mse)
## [1] 0.6632398

Interpretation: the test set MSE associated with the regression tree is 0.44. The square root of the MSE is therefore around 0.663, indicating that this model leads to test predictions that are typically within around 0.663 of the true log weekly net pay. Note that this error is on the log scale, not in pounds.
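Because NetPay was log-transformed, it can help to translate the log-scale error back to pounds. A minimal sketch, where the helper `mse_rmse` and the toy values are hypothetical: it computes the MSE and its square root, then exponentiates a log-scale error to get the corresponding multiplicative factor.

```r
# Hypothetical helper: MSE and RMSE of predictions against observed values
mse_rmse <- function(y_hat, y_obs) {
  mse <- mean((y_hat - y_obs)^2, na.rm = TRUE)
  c(mse = mse, rmse = sqrt(mse))
}

# Toy log-scale values (assumed for illustration)
res <- mse_rmse(y_hat = c(5.0, 5.5, 6.2), y_obs = c(5.3, 5.1, 6.0))
res

# A log-scale error of 0.663 corresponds to a multiplicative factor in pounds
exp(0.663)
```

Since exp(0.663) is about 1.94, a log-scale error of 0.663 means a prediction that is roughly a factor of two above or below the observed weekly net pay.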

3 Random Forest

Decision trees suffer from high variance; that is, if we split the training data into two parts at random, and fit a decision tree to both halves, the results that we get could be quite different. Random Forest is well suited to reducing variance.

3.1 The Idea

The idea is to take many training datasets (in practice, bootstrap samples drawn from the training data), build a separate prediction model on each and average the resulting predictions, which reduces variance and increases prediction accuracy. Random Forest adds a twist: each time a split in a tree is considered, only a random sample of the predictors is allowed as split candidates. This decorrelates the trees.
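Decorrelation matters because averaging \(B\) predictions that each have variance \(\sigma^2\) and pairwise correlation \(\rho\) gives a variance of \(\rho\sigma^2 + \frac{1-\rho}{B}\sigma^2\): more trees shrink only the second term, so lowering \(\rho\) is what keeps the first term small. The sketch below is a deliberately simplified "forest" of one-split trees (stumps), each grown on a bootstrap sample and restricted to a random subset of `mtry` predictors. All function names and the toy data are hypothetical; `randomForest()` grows deep trees and is far more efficient.

```r
# Hypothetical helper: best single split among the predictors in `vars`
fit_stump <- function(X, y, vars) {
  best <- list(rss = Inf, mean_all = mean(y))
  for (v in vars) {
    xs <- sort(unique(X[, v]))
    if (length(xs) < 2) next                      # nothing to split on
    for (t in (head(xs, -1) + tail(xs, -1)) / 2) {
      left <- X[, v] < t
      rss <- sum((y[left] - mean(y[left]))^2) +
        sum((y[!left] - mean(y[!left]))^2)
      if (rss < best$rss) {
        best <- list(rss = rss, mean_all = mean(y), var = v, cut = t,
                     mean_left = mean(y[left]), mean_right = mean(y[!left]))
      }
    }
  }
  best
}

# Predict with one stump (falls back to the overall mean if it never split)
predict_stump <- function(stump, newx) {
  if (is.null(stump$var)) return(rep(stump$mean_all, nrow(newx)))
  ifelse(newx[, stump$var] < stump$cut, stump$mean_left, stump$mean_right)
}

# Grow B stumps, each on a bootstrap sample with a random predictor subset
stump_forest <- function(X, y, B = 100, mtry = max(1, floor(ncol(X) / 3))) {
  lapply(seq_len(B), function(b) {
    idx  <- sample(nrow(X), replace = TRUE)       # bootstrap sample of rows
    vars <- sample(ncol(X), mtry)                 # random subset of predictors
    fit_stump(X[idx, , drop = FALSE], y[idx], vars)
  })
}

# Forest prediction: average the predictions of all stumps
predict_forest <- function(forest, newx) {
  Reduce(`+`, lapply(forest, predict_stump, newx = newx)) / length(forest)
}

# Toy data (assumed): 6 predictors, only the first two matter
set.seed(123)
n <- 300
X <- matrix(runif(n * 6), ncol = 6)
y <- 2 * (X[, 1] > 0.5) + X[, 2] + rnorm(n, sd = 0.3)
forest <- stump_forest(X, y, B = 100, mtry = 2)
y_hat <- predict_forest(forest, X)
mean((y_hat - y)^2)
```

Here `mtry = 2` follows the usual \(p/3\) rule for 6 predictors. A smaller `mtry` makes the trees less alike (lower \(\rho\)) at the cost of each individual tree being weaker.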

3.2 Fitting a Random Forest

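  1. Read the data and create a reduced dataset:
# clean workspace
rm(list=ls())

# load packages (dplyr for data manipulation, randomForest for random forests)
library(dplyr)
library(randomForest)

# load data
load("../data/data_qlfs.RData")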

# select variables to create a new df
data <- qlfs %>%
  dplyr::select(NetPay, FamilySize, Age, Sex, MaritalStatus, NSSEC,
                EthnicGroup, HighestQual, Tenure, TravelTime) %>%
  filter(!is.na(NetPay)) %>%
  filter(NetPay > 0)

# log-transform NetPay
data$NetPay <- log(data$NetPay)

# remove qlfs from the workspace
rm(qlfs)
  2. Create training (70%) and test (30%) sets and use set.seed() for reproducibility:
# create an index variable to identify a training set
set.seed(123)
data_train <- sample(
  1:nrow(data),
  round(0.7* (nrow(data)))
  )
  3. Fit a random forest that considers 3 of the 9 predictors (usually \(m \approx p/3\)) as split candidates at each split:
# fit random forest
m1_rf <- randomForest(
  formula = NetPay ~ .,
  data    = data,
  subset = data_train,
  mtry = 3, 
  importance = TRUE,
  na.action=na.exclude
  )
m1_rf
## 
## Call:
##  randomForest(formula = NetPay ~ ., data = data, mtry = 3, importance = TRUE,      subset = data_train, na.action = na.exclude) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 3
## 
##           Mean of squared residuals: 0.2721213
##                     % Var explained: 43.73

3.3 Prediction

# predict log net pay for the test set
yhat_rf <- predict(m1_rf, newdata = data[-data_train, ])

# observed log net pay for the test set
data_test <- data$NetPay[-data_train]

# compute the mean squared error (MSE), ignoring any NA predictions
# (e.g. test rows with missing predictor values)
mse <- mean((yhat_rf - data_test)^2, na.rm = TRUE)
mse

Conclusion: compare this test set MSE with the 0.44 obtained for the regression tree. The out-of-bag estimate reported by m1_rf above (mean of squared residuals 0.272) already suggests that random forests yield an improvement over regression trees.

3.4 Importance

We can assess the importance of each variable using two measures. The %IncMSE measure is based on the mean decrease in prediction accuracy on the out-of-bag samples when the values of a given variable are randomly permuted, which breaks its relationship with the response. The IncNodePurity measure is the total decrease in node impurity that results from splits over that variable, averaged over all trees. For regression trees, node impurity is measured by the training RSS; for classification trees, randomForest measures it with the Gini index.
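The permutation idea behind %IncMSE can be sketched without the forest. In the example below a linear model and a held-out test set stand in for the trees and their out-of-bag samples (an assumption made to keep the example short and dependency-free); the simulated variables are hypothetical. Shuffling an important variable should raise the held-out MSE a lot, while shuffling a pure-noise variable should barely change it.

```r
# Simulated data (assumed): y depends strongly on x1, weakly on x2, not on x3
set.seed(99)
n <- 500
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
dat$y <- 3 * dat$x1 + 1 * dat$x2 + rnorm(n)

# Fit on a training set and keep the rest as held-out data
train <- sample(n, 350)
fit   <- lm(y ~ ., data = dat[train, ])
test  <- dat[-train, ]
base_mse <- mean((predict(fit, newdata = test) - test$y)^2)

# Permutation importance: increase in held-out MSE after shuffling one variable
perm_importance <- sapply(c("x1", "x2", "x3"), function(v) {
  shuffled <- test
  shuffled[[v]] <- sample(shuffled[[v]])   # breaks the link between v and y
  mean((predict(fit, newdata = shuffled) - test$y)^2) - base_mse
})
round(perm_importance, 2)
```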

varImpPlot(m1_rf)

4 Cross-validation

Test Error: the average error that results from predicting the response for observations in an unseen test dataset, i.e. observations not used to train the model.

Training Error: the average error that results from predicting the response for the observations included in the training dataset.
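k-fold cross-validation estimates the test error using only the available data: split the observations into k folds, train on k − 1 folds, compute the error on the held-out fold, and average over the k folds. A minimal base-R sketch, where the helper `kfold_mse`, the polynomial linear models and the simulated data are assumptions (the helper also assumes the response column is called `y`). For these nested least-squares fits the training MSE can only fall as the polynomial degree rises, while the cross-validated MSE reflects the test error.

```r
# Hypothetical helper: k-fold cross-validation estimate of the test MSE
# (assumes the response column in `data` is called y)
kfold_mse <- function(data, formula, k = 10) {
  # assign each row at random to one of k folds of (almost) equal size
  folds <- sample(rep(seq_len(k), length.out = nrow(data)))
  fold_mse <- sapply(seq_len(k), function(j) {
    fit  <- lm(formula, data = data[folds != j, ])   # train on k - 1 folds
    held <- data[folds == j, ]                        # held-out fold
    mean((predict(fit, newdata = held) - held$y)^2)
  })
  mean(fold_mse)
}

# Simulated data (assumed): a non-linear relationship plus noise
set.seed(2020)
n <- 200
sim <- data.frame(x = runif(n, -2, 2))
sim$y <- sin(sim$x) + rnorm(n, sd = 0.3)

# Compare training MSE with the 10-fold CV estimate of test MSE
for (d in c(1, 3, 10)) {
  f <- as.formula(paste0("y ~ poly(x, ", d, ")"))
  train_mse <- mean(residuals(lm(f, data = sim))^2)
  cv_mse    <- kfold_mse(sim, f, k = 10)
  cat(sprintf("degree %2d: training MSE %.3f, 10-fold CV MSE %.3f\n",
              d, train_mse, cv_mse))
}
```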