In this exercise we will explore how to implement the K-nearest neighbors (KNN) algorithm for classification.
This data set consists of percentage returns for the S&P 500 stock index over 1,250 days, from the beginning of 2001 until the end of 2005. For each date, we have recorded the percentage returns for each of the five previous trading days, Lag1 through Lag5. We have also recorded Volume (the number of shares traded on the previous day, in billions), Today (the percentage return on the date in question), and Direction (whether the market was Up or Down on this date).
library(ISLR)
names(Smarket)
## [1] "Year" "Lag1" "Lag2" "Lag3" "Lag4" "Lag5"
## [7] "Volume" "Today" "Direction"
attach(Smarket)
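A quick check confirms the dimensions: 1,250 daily observations of 9 variables.
# Check the dimensions of the data
dim(Smarket)
## [1] 1250 9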
#install.packages("class")
library(class)
Separate out the 2005 data, which is the most recent in the data set. This mirrors how forecasting is done in a business setting: we use past data to predict the future.
# YEAR 2005
train <- (Year < 2005)
Smarket.2005 <- Smarket[!train, ]
dim(Smarket.2005)
## [1] 252 9
Direction.2005 <- Direction[!train]
# Create a variable flagging whether each row belongs to the training or testing period
Smarket$Year05 <- "No"
Smarket$Year05[which(Smarket$Year == 2005)] <- "Yes"
We only want to predict the market Direction based on Lag1 and Lag2. Let's create a plot to visualize this! Remember, KNN classifies each point by a majority vote among only its closest neighbors, where closeness is measured by Euclidean distance.
library(tidyverse)
ggplot(Smarket, aes(Lag1, Lag2, color = Direction, pch = Year05)) +
  geom_jitter(alpha = 0.5) +
  theme_bw()
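To make the distance metric concrete, here is a small illustration (not part of the original lab; the variables p1 and p2 are just for demonstration) of the Euclidean distance between the first two observations in the (Lag1, Lag2) plane. knn() computes all of these distances internally.
# Euclidean distance between the first two observations in (Lag1, Lag2) space
p1 <- c(Lag1[1], Lag2[1])
p2 <- c(Lag1[2], Lag2[2])
sqrt(sum((p1 - p2)^2))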
train.X <- cbind(Lag1, Lag2)[train, ]
test.X <- cbind(Lag1, Lag2)[!train, ]
train.Direction <- Direction[train]
Now we implement the KNN algorithm, starting with k = 1.
set.seed(1)
knn.pred <- knn(train.X, test.X, train.Direction, k = 1)
#Confusion matrix
table(knn.pred,Direction.2005)
## Direction.2005
## knn.pred Down Up
## Down 43 58
## Up 68 83
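From this confusion matrix, k = 1 classifies (43 + 83) / 252 = 50% of the 2005 observations correctly, no better than a coin flip.
# Overall accuracy for k = 1
mean(knn.pred == Direction.2005)
## [1] 0.5
Let's see whether a slightly larger neighborhood does better, with k = 3.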
knn.pred <- knn(train.X, test.X, train.Direction, k = 3)
#Confusion matrix
table(knn.pred,Direction.2005)
## Direction.2005
## knn.pred Down Up
## Down 48 54
## Up 63 87
mean(knn.pred==Direction.2005)
## [1] 0.5357143
detach(Smarket)
The iris dataset is a popular dataset used for statistical examples. The documentation for this dataset in R states:
This famous (Fisher’s or Anderson’s) iris data set gives the measurements in centimeters of the variables sepal length and width and petal length and width, respectively, for 50 flowers from each of 3 species of iris. The species are Iris setosa, versicolor, and virginica.
Our goal is to be able to predict the species of an unknown iris flower, given measurements for four physical characteristics.
head(iris)
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1 5.1 3.5 1.4 0.2 setosa
## 2 4.9 3.0 1.4 0.2 setosa
## 3 4.7 3.2 1.3 0.2 setosa
## 4 4.6 3.1 1.5 0.2 setosa
## 5 5.0 3.6 1.4 0.2 setosa
## 6 5.4 3.9 1.7 0.4 setosa
Since we are using a distance metric to assess closeness of neighbors, it is essential to standardize our data.
# unstandardized variances differ across columns
var(iris[, 1])
## [1] 0.6856935
var(iris[, 2])
## [1] 0.1899794
# use scale() to standardize
irisS <- iris
irisS[, 1:4] <- scale(irisS[, 1:4])
irisS <- data.frame(irisS)
# now check the standardized variances
var(irisS[, 1])
## [1] 1
var(irisS[, 2])
## [1] 1
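As a sanity check, scale() centers each column at its mean and divides it by its standard deviation; the variable manual below is just for illustration.
# Manual standardization of Sepal.Length should match the scale() result
manual <- (iris$Sepal.Length - mean(iris$Sepal.Length)) / sd(iris$Sepal.Length)
all.equal(manual, irisS$Sepal.Length)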
Looking at the data, do there appear to be any clear groupings?
### Are there groups?
library(GGally)
ggpairs(irisS)
## explore the species groupings visually
library(tidyverse)
ggplot(data = irisS, aes(x = Sepal.Length, y = Sepal.Width, col = Species)) +
geom_point()
ggplot(data = irisS, aes(x = Petal.Length, y = Petal.Width, col = Species)) +
geom_point()
The petal measurements separate the three species much more cleanly than the sepal measurements. Before fitting KNN, we want to make sure that we have balanced representation of each species in the training and testing sets.
### train and test
set.seed(239)
setosa <- irisS %>% filter(Species == "setosa")
versicolor <- irisS %>% filter(Species == "versicolor")
virginica <- irisS %>% filter(Species == "virginica")
# 50 observations from each
# Take 60% to train (30) and 40% to test (20)
train <- sample(1:50, 30)
iris.train <- rbind(setosa[train, ], versicolor[train, ], virginica[train, ])
iris.test <- rbind(setosa[-train, ], versicolor[-train, ], virginica[-train, ])
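Because the same 30 sampled indices are used within each species, the split is stratified by construction; a quick check (illustrative, not in the original write-up) should show 30 training and 20 testing flowers per species.
# Confirm balanced representation across species
table(iris.train$Species)
table(iris.test$Species)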
library(class)
knnIris <- knn(train = iris.train[, 1:4],
               test = iris.test[, 1:4],
               cl = iris.train$Species,
               k = 3)
## correct
mean(knnIris==iris.test$Species)
## [1] 0.9833333
## error
mean(knnIris!=iris.test$Species)
## [1] 0.01666667
## Pick a neighborhood
error <- c()
for (i in 1:15) {
  knnIris <- knn(train = iris.train[, 1:4], test = iris.test[, 1:4], cl = iris.train$Species, k = i)
  error[i] <- 1 - mean(knnIris == iris.test$Species)
}
ggplot(data = data.frame(error), aes(x = 1:15, y = error)) +
  geom_line(color = "blue") +
  xlab("Neighborhood Size")
iris_pred <- knn(train = iris.train[, 1:4],
                 test = iris.test[, 1:4],
                 cl = iris.train$Species,
                 k = 3)
table(iris.test$Species, iris_pred)
##             iris_pred
##              setosa versicolor virginica
##   setosa         20          0         0
##   versicolor      0         20         0
##   virginica       0          1        19
Finally, we turn to the Caravan data set, which is part of the ISLR library. This data set includes 85 predictors that measure demographic characteristics for 5,822 individuals. The response variable is Purchase, which indicates whether or not a given individual purchases a caravan insurance policy. In this data set, only 6% of people purchased caravan insurance.
### Application: Caravan Insurance
library(ISLR)
data("Caravan")
dim(Caravan)
## [1] 5822 86
attach(Caravan)
summary(Purchase)
## No Yes
## 5474 348
# Percent that purchased
348/dim(Caravan)[1]
## [1] 0.05977327
Since the KNN algorithm uses a distance metric, it is very important to standardize the variables first. For instance, consider the variables salary and age: they are on completely different scales. Salary varies far more than age, so without standardization the distance calculation would be dominated by salary and the age variable's influence would be lost.
So first, we standardize the variables.
# KNN uses a distance metric
# so standardizing the units is VERY important!
# consider salary and age
# take out 86th col b/c Purchase variable
standardized.X <- scale(Caravan[, -86])
# unstandardized
var(Caravan[, 1])
## [1] 165.0378
var(Caravan[, 2])
## [1] 0.1647078
# standardized
var(standardized.X[, 1])
## [1] 1
var(standardized.X[, 2])
## [1] 1
Since we want to predict insurance sales, let's split the data into training and test sets, fit the model, and evaluate it.
# Split data into test and train
test <- 1:1000
train.X <- standardized.X[-test, ]
test.X <- standardized.X[test, ]
train.Y <- Purchase[-test]
test.Y <- Purchase[test]
# K=1
set.seed(1)
knn.pred1 <- knn(train.X, test.X, train.Y, k = 1)
#Error rate
mean(test.Y!=knn.pred1)
## [1] 0.118
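# For comparison: a trivial classifier that always predicts "No" would err
# only 5.9% of the time, so raw error rate is a misleading benchmark here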
mean(test.Y!="No")
## [1] 0.059
# confusion matrix
table(knn.pred1, test.Y)
## test.Y
## knn.pred1 No Yes
## No 873 50
## Yes 68 9
# Among those predicted to buy, 11.7% actually purchase: about double the 5.9% base rate
9/(68+9)
## [1] 0.1168831
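Rather than reading this rate off the confusion matrix by hand, we could compute it directly (a minimal sketch; precision here means the fraction of predicted buyers who actually purchased):
# Precision among predicted buyers for k = 1
mean(test.Y[knn.pred1 == "Yes"] == "Yes")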
# K=3
knn.pred3 <- knn(train.X, test.X, train.Y, k = 3)
table(knn.pred3, test.Y)
## test.Y
## knn.pred3 No Yes
## No 920 54
## Yes 21 5
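# Precision among predicted buyers at k = 3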
5/26
## [1] 0.1923077
# K=5
knn.pred5 <- knn(train.X, test.X, train.Y, k = 5)
table(knn.pred5, test.Y)
## test.Y
## knn.pred5 No Yes
## No 930 55
## Yes 11 4
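# Precision among predicted buyers at k = 5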
4/15
## [1] 0.2666667
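As with the iris example, we could sweep over several neighborhood sizes and track this precision; a sketch under the same train/test split (the variables precision and pred are illustrative, and precision is NaN for any k at which no one is predicted to buy):
# Sweep k from 1 to 10 and record precision among predicted buyers
precision <- c()
for (i in 1:10) {
  pred <- knn(train.X, test.X, train.Y, k = i)
  precision[i] <- mean(test.Y[pred == "Yes"] == "Yes")
}
precision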
Disclaimers:
Examples 1 and 3 are based on the R code provided in “Introduction to Statistical Learning”
Example 2 is based on “KNN Classification on Iris Data” by Devanshu Awasthi