library(rsample) # data splitting
library(gbm) # basic implementation
library(xgboost) # a faster implementation of gbm
library(caret) # an aggregator package for performing many machine learning
## Load data-handling packages and ISLR (source of the College data set)
library(tidyverse)
library(ISLR)
## Modelling table: the `College` data set attached by ISLR
ml_data <- College
# Quick look: the first five rows of the first five columns, then the size
head(ml_data[, 1:5], n = 5)
dim(ml_data)
# Partition into training (70%) and test (30%) data. caret samples within
# the levels of the outcome `Private`, so both parts keep its class balance.
# The fixed seed makes the split reproducible.
set.seed(42)
index <- createDataPartition(
  y = ml_data[["Private"]],
  p = 0.7,
  list = FALSE
)
train_data <- ml_data[index, ]
test_data <- ml_data[-index, ]
# Train a gradient boosting machine (caret method "gbm") on all predictors,
# tuning with 5-fold cross-validation repeated 3 times.
# No preprocessing is applied; tree ensembles do not need centring/scaling.
# Seed right before training so the resampling folds are reproducible even
# when this chunk is run on its own (independent of earlier RNG use).
set.seed(42)
model_gbm <- caret::train(
  Private ~ .,
  data = train_data,
  method = "gbm",
  trControl = trainControl(
    method = "repeatedcv",
    number = 5,
    repeats = 3,
    verboseIter = FALSE
  ),
  # Forwarded to gbm; it is a logical flag, so use FALSE rather than 0.
  verbose = FALSE
)
model_gbm
## Test: score the GBM model on the held-out test set and tabulate the
## predicted classes against the observed `Private` labels.
model_gbm %>%
  predict(newdata = test_data) %>%
  caret::confusionMatrix(reference = test_data[["Private"]])
# XGBoost (caret method "xgbTree") evaluated with plain 5-fold CV.
trctrl <- trainControl(method = "cv", number = 5)
# Fixed hyper-parameter grid: every parameter is pinned except `nrounds`,
# which is searched over 140..150, giving 11 candidate models.
tune_grid <- expand.grid(
  nrounds = 140:150,
  max_depth = 5,
  eta = 0.05,
  gamma = 0.01,
  colsample_bytree = 0.75,
  min_child_weight = 0,
  subsample = 0.5
)
# Seed right before training so the CV folds are reproducible on their own.
set.seed(42)
# No `tuneLength`: caret ignores it whenever `tuneGrid` is supplied, so the
# grid above fully defines the search.
xgb.model <- caret::train(
  Private ~ .,
  data = train_data,
  method = "xgbTree",
  trControl = trctrl,
  tuneGrid = tune_grid
)
# have a look at the model
xgb.model
# Testing: predict the held-out rows with the XGBoost model, keep the
# predictions in `test_predict`, and compare them with the true labels.
test_predict <- predict(xgb.model, newdata = test_data)
caret::confusionMatrix(test_predict, test_data[["Private"]])