This session provides an introduction to supervised machine learning. It is part of Introduction to Statistical Learning in R. Supervised Machine Learning – Tree Regressions, Random Forest & Cross-validation by Francisco Rowe is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
Machine learning traces back to Samuel (1959; “Some Studies in Machine Learning Using the Game of Checkers.” IBM Journal of Research and Development 44: 206–26) and is conceived as a subset of artificial intelligence. Machine learning algorithms build mathematical models from sample data in order to make predictions or decisions without being explicitly programmed to perform the task.
Formally, we estimate a function \(\hat{f}\) that maps a set of predictors \(x\) onto a response \(y\) and use it to produce predictions:
\[\hat{y} = \hat{f}(x) \]
We may estimate \(\hat{f}\) with two different goals in mind:
Prediction: We are not concerned with the exact form of \(\hat{f}\). It is treated as a black box. The interest is in accurate predictions.
Inference: We are interested in the form of \(\hat{f}\), i.e. the way \(y\) is affected as \(x\) changes. The goal is not necessarily to make predictions.
Machine learning revolves around the problem of prediction: producing accurate predictions of \(y\) from \(x\). Machine learning algorithms are not built to understand the underlying relationship between \(y\) and \(x\).
Supervised: For each observation \(x_{i}\) of the explanatory variables, there is an associated response \(y_{i}\).
Unsupervised: For each observation \(i\) we observe measurements \(x_{i}\) but no associated response \(y\).
Regression problems ~ quantitative response variable
Classification problems ~ qualitative response variable
Regression trees partition a data set into smaller subgroups and then fit a simple constant to each subgroup. The partitioning is achieved by successive binary splits on the different predictors. As in linear regression, these splits are chosen by minimising the \(RSS\), i.e. minimising the total squared difference between the observed and predicted \(y\). The predicted constant for a subgroup is the average response value of all observations that fall in that subgroup. To assess how well this works, we train the model on a subset of our data and use it to predict an unseen portion of the sample.
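In notation, if the tree has \(J\) terminal regions \(R_{1}, \dots, R_{J}\), the splits are chosen to minimise
\[RSS = \sum_{j=1}^{J} \sum_{i \in R_{j}} (y_{i} - \hat{y}_{R_{j}})^2 \]
where \(\hat{y}_{R_{j}}\) is the mean response of the training observations falling in region \(R_{j}\).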
An example: We want to predict an individual’s net pay based on age (as a measure of experience) and sex (as a measure of gender discrimination in the labour market). We first remove observations with missing or non-positive income values and log-transform NetPay in our QLFS data so its distribution has a more typical bell shape.
# clean workspace
rm(list = ls())
# load required packages
library(dplyr)   # data manipulation (%>%, select, filter)
library(tree)    # regression trees
# load data
load("../data/data_qlfs.RData")
# select variables to create a new df
data <- qlfs %>%
  dplyr::select(NetPay, Sex, Age) %>%
  filter(!is.na(NetPay)) %>%
  filter(NetPay > 0)
# log-transform NetPay
data$NetPay <- log(data$NetPay)
# remove qlfs from the workspace
rm(qlfs)
We use set.seed() for reproducibility.
# create an index variable to identify a training set
set.seed(123)
data_train <- sample(
  1:nrow(data),
  round(0.7 * nrow(data))
)
# fit regression tree
m1_tree <- tree(
  formula = NetPay ~ .,
  data = data,
  subset = data_train
)
summary(m1_tree)
##
## Regression tree:
## tree(formula = NetPay ~ ., data = data, subset = data_train)
## Number of terminal nodes: 5
## Residual mean deviance: 0.418 = 3203 / 7663
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -5.43800 -0.33630 0.04319 0.00000 0.40800 2.25600
The fitted tree summarised above is plotted in Fig. 2.
# plot the tree
plot(m1_tree)
text(m1_tree, pretty = 0)
The process described above may produce good predictions on the training set, but it is likely to overfit the data, leading to poor test set performance. This is because the resulting tree might be too complex.
The key idea is to minimise the expected test error: find a smaller tree with fewer splits that leads to lower variance and better interpretation at the cost of a little bias, i.e. the right balance.
To achieve a good balance between variance and bias, a common approach is cost complexity pruning: grow a very large tree and then prune it back to obtain a subtree; that is, reduce the complexity of the tree by finding the subtree that returns the lowest test error rate. We need to find the tuning parameter \(\alpha\) that penalises the complexity of the tree.
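For each value of \(\alpha\), cost complexity pruning selects the subtree \(T\) that minimises
\[\sum_{m=1}^{|T|} \sum_{i:\, x_{i} \in R_{m}} (y_{i} - \hat{y}_{R_{m}})^2 + \alpha |T| \]
where \(|T|\) is the number of terminal nodes of the subtree, so larger values of \(\alpha\) penalise more complex trees. The value of \(\alpha\) is chosen by cross-validation, as done below.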
# run cross-validation to choose tree size
cv_data <- cv.tree(m1_tree)
# plot cross-validated deviance against tree size
plot(cv_data$size, cv_data$dev, type = 'b')
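The tree size with the lowest cross-validated deviance can also be read off programmatically rather than from the plot; a minimal sketch (best_size is simply an illustrative name):
# tree size (number of terminal nodes) with the lowest CV deviance
best_size <- cv_data$size[which.min(cv_data$dev)]
best_size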
Fig. 3. Pruned tree with three terminal nodes.
# prune the tree back to 3 terminal nodes
prune_data <- prune.tree(m1_tree, best = 3)
plot(prune_data)
text(prune_data, pretty = 0)
To make predictions, we use the unpruned tree on the test data set.
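The accuracy of the predictions is summarised by the test mean squared error (MSE), which the code below computes:
\[MSE_{test} = \frac{1}{n_{test}} \sum_{i \in test} (y_{i} - \hat{y}_{i})^2 \]
where the sum runs over the \(n_{test}\) observations held out of the training set.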
# predict net pay for the test set (observations not used in training)
yhat <- predict(m1_tree, newdata = data[-data_train, ])
# extract the observed net pay values for the test set
data_test <- data[-data_train, "NetPay"]
# compute the test mean squared error (MSE)
mse <- mean((yhat - data_test)^2)
mse
## [1] 0.439887
# take the square root
sqrt(mse)
## [1] 0.6632398
Interpretation: the test set MSE associated with the regression tree is 0.44. Its square root is around 0.663, indicating that this model leads to average test predictions that are within around 0.663 of the true log of weekly net pay (the error is on the log scale, not in pounds).
Decision trees suffer from high variance: if we split the training data into two parts at random and fit a decision tree to each half, the results could be quite different. Random forests are well suited to reducing this variance.
The idea is to take repeated samples of the training data, build a separate prediction tree on each, and average the resulting predictions, thereby reducing the variance and increasing the accuracy of the predictions. In addition, random forests use only a random sample of the predictors as candidates at each split in a tree (for regression, the default in randomForest() is roughly one third of the predictors; with the nine predictors used below this gives mtry = 3). This decorrelates the trees.
# clean workspace
rm(list = ls())
# load required package
library(randomForest)   # random forests
# read data
load("../data/data_qlfs.RData")
# select variables to create a new df
data <- qlfs %>%
  dplyr::select(NetPay, FamilySize, Age, Sex, MaritalStatus, NSSEC,
                EthnicGroup, HighestQual, Tenure, TravelTime) %>%
  filter(!is.na(NetPay)) %>%
  filter(NetPay > 0)
# log-transform NetPay
data$NetPay <- log(data$NetPay)
# remove qlfs from the workspace
rm(qlfs)
We use set.seed() for reproducibility.
# create an index variable to identify a training set
set.seed(123)
data_train <- sample(
  1:nrow(data),
  round(0.7 * nrow(data))
)
# fit random forest
m1_rf <- randomForest(
  formula = NetPay ~ .,
  data = data,
  subset = data_train,
  mtry = 3,
  importance = TRUE,
  na.action = na.exclude
)
m1_rf
##
## Call:
## randomForest(formula = NetPay ~ ., data = data, mtry = 3, importance = TRUE, subset = data_train, na.action = na.exclude)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 3
##
## Mean of squared residuals: 0.2721213
## % Var explained: 43.73
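Before turning to the test set, the out-of-bag error can be inspected as a function of the number of trees; for a regression forest, plotting the fitted object shows the OOB mean squared error against the number of trees grown:
# OOB mean squared error as the number of trees grows
plot(m1_rf)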
# predict net pay for the test set
yhat_rf <- predict(m1_rf, newdata = data[-data_train, ])
# extract the observed net pay values for the test set
data_test <- data[-data_train, "NetPay"]
# compute the test mean squared error (MSE), ignoring observations with
# missing predictions (arising from na.action = na.exclude above)
mse <- mean((yhat_rf - data_test)^2, na.rm = TRUE)
mse
Conclusion: the random forest improves on the single regression tree; its out-of-bag mean of squared residuals (0.27) is well below the tree's test set MSE of 0.44.
We can assess the importance of each variable using two measures. %IncMSE (shown below) is based on the mean decrease in prediction accuracy on the out-of-bag samples when the values of a variable are randomly permuted. IncNodePurity measures the total decrease in node impurity that results from splits over that variable, averaged over all trees. For regression trees, node impurity is measured by the training RSS; for classification trees, by the deviance.
varImpPlot(m1_rf)
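The numeric scores behind the plot can also be inspected directly with randomForest's importance() function; a minimal sketch:
# variable importance scores: %IncMSE and IncNodePurity
importance(m1_rf)
# order predictors by the permutation-based measure (%IncMSE)
sort(importance(m1_rf)[, "%IncMSE"], decreasing = TRUE)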
Test Error: Average error of predicting an observation from an unseen test dataset, i.e. one not used in the training dataset.
Training Error: Average error of predicting an observation included in the training dataset.
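k-fold cross-validation estimates the test error by splitting the data into \(k\) folds, fitting the model on \(k-1\) folds, computing the MSE on the held-out fold, and averaging over the \(k\) folds:
\[CV_{(k)} = \frac{1}{k} \sum_{i=1}^{k} MSE_{i} \]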
An example:
Read data and create a training sample
# load required packages
library(ISLR)   # provides the Auto data set
library(boot)   # provides cv.glm()
# read data
Auto <- Auto
# create training sample (196 of the 392 observations)
set.seed(1)
train <- sample(392, 196)
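The train index can be used for a quick validation-set estimate of the test error before moving to cross-validation; a minimal sketch, using an illustrative linear model of mpg on horsepower:
# fit a linear model on the training half only
lm.fit <- lm(mpg ~ horsepower, data = Auto, subset = train)
# validation-set MSE on the held-out half
mean((Auto$mpg - predict(lm.fit, Auto))[-train]^2)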
Run a 10-fold cross-validation
set.seed(17)
# vector to store the 10-fold CV error for polynomial degrees 1 to 10
cv.error.10 <- rep(0, 10)
for (i in 1:10) {
  # polynomial regression of mpg on horsepower of degree i
  glm.fit <- glm(mpg ~ poly(horsepower, i), data = Auto)
  # 10-fold cross-validation estimate of the test error
  cv.error.10[i] <- cv.glm(Auto, glm.fit, K = 10)$delta[1]
}
cv.error.10
## [1] 24.27207 19.26909 19.34805 19.29496 19.03198 18.89781 19.12061 19.14666
## [9] 18.87013 20.95520
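The polynomial degree with the lowest estimated test error (degree 9 in the output above, although the errors for degrees 2 and higher are very similar) can be identified with which.min():
# degree with the smallest cross-validated error
which.min(cv.error.10)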
Function | Description
---|---
sample() | create an index vector for a sample
set.seed() | set a seed number
tree() | fit a regression tree
plot(), text() | plot a tree
cv.tree() | perform k-fold cross-validation for trees
prune.tree() | prune a tree
randomForest() | fit a random forest
varImpPlot() | plot variable importance for a random forest
predict() | generate predictions
cv.glm() | run k-fold cross-validation for a GLM