# Setup ----
# Classification-tree lab on the ISLR Carseats data. The workspace is not
# wiped (rm(list = ls())) and the data are not attach()ed: columns are
# referenced explicitly so the script does not depend on global state.
library(ISLR)
library(tree)

head(Carseats)
##   Sales CompPrice Income Advertising Population Price ShelveLoc Age
## 1  9.50       138     73          11        276   120       Bad  42
## 2 11.22       111     48          16        260    83      Good  65
## 3 10.06       113     35          10        269    80    Medium  59
## 4  7.40       117    100           4        466    97    Medium  55
## 5  4.15       141     64           3        340   128       Bad  38
## 6 10.81       124    113          13        501    72       Bad  78
##   Education Urban  US
## 1        17   Yes Yes
## 2        10   Yes Yes
## 3        12   Yes Yes
## 4        14   Yes Yes
## 5        13   Yes  No
## 6        16    No Yes
range(Carseats$Sales)
## [1]  0.00 16.27

# Create a binary categorical response from Sales: "Yes" when Sales >= 8,
# otherwise "No". It must be a factor: tree() grows a classification tree
# only for a factor response, and since R 4.0 data.frame() no longer turns
# character vectors into factors, so a character High breaks the fit below.
High <- factor(ifelse(Carseats$Sales >= 8, "Yes", "No"))
Carseats <- data.frame(Carseats, High)
head(Carseats)
##   Sales CompPrice Income Advertising Population Price ShelveLoc Age
## 1  9.50       138     73          11        276   120       Bad  42
## 2 11.22       111     48          16        260    83      Good  65
## 3 10.06       113     35          10        269    80    Medium  59
## 4  7.40       117    100           4        466    97    Medium  55
## 5  4.15       141     64           3        340   128       Bad  38
## 6 10.81       124    113          13        501    72       Bad  78
##   Education Urban  US High
## 1        17   Yes Yes  Yes
## 2        10   Yes Yes  Yes
## 3        12   Yes Yes  Yes
## 4        14   Yes Yes   No
## 5        13   Yes  No   No
## 6        16    No Yes  Yes

# Remove Sales, by name rather than by position: High is derived from it, so
# leaving it in would let the tree read the answer directly.
Carseats$Sales <- NULL
names(Carseats)
##  [1] "CompPrice"   "Income"      "Advertising" "Population"  "Price"      
##  [6] "ShelveLoc"   "Age"         "Education"   "Urban"       "US"         
## [11] "High"
# Split the data into training and test sets ----
# Half of the rows, drawn at random, form the training set; the rest are
# held out for testing. seq_len() draws the same sample as 1:nrow() but is
# also safe for a zero-row data frame.
set.seed(2)
train <- sample(seq_len(nrow(Carseats)), nrow(Carseats) / 2)
test <- -train  # negative indices select every row NOT in train

training_data <- Carseats[train, ]
testing_data <- Carseats[test, ]

# Test labels are taken from the test rows themselves, so they are
# guaranteed to line up with testing_data row for row.
testing_High <- testing_data$High
# Fit an unpruned classification tree on the training rows ----
# The response is High; every remaining column is a candidate predictor.
tree_model <- tree(High ~ ., data = training_data)

# Draw the fitted tree. pretty = 0 labels splits on qualitative predictors
# with their category names rather than single-letter codes.
plot(tree_model)
text(tree_model, pretty = 0)

# Evaluate on the held-out rows: fraction of test cases classified wrongly.
tree_pred <- predict(tree_model, newdata = testing_data, type = "class")
mean(tree_pred != testing_High)
## [1] 0.285
# Prune the tree ----
# Use cross-validation (cv.tree, pruning by misclassification count) to
# decide how far to prune.
set.seed(3)
cv_tree <- cv.tree(tree_model, FUN = prune.misclass)
names(cv_tree)
## [1] "size"   "dev"    "k"      "method"
plot(cv_tree$size, cv_tree$dev, type = "b",
     xlab = "Tree size (terminal nodes)", ylab = "CV misclassifications")

# Choose the smallest tree that attains the minimum CV error, instead of
# hard-coding a size read off the plot. The best size depends on the random
# split and fold assignment, so a fixed number can silently go stale.
best_size <- min(cv_tree$size[cv_tree$dev == min(cv_tree$dev)])
best_size

pruned_model <- prune.misclass(tree_model, best = best_size)
plot(pruned_model)
text(pruned_model, pretty = 0)

# Check how the pruned tree does on the test set (misclassification rate).
tree_pred <- predict(pruned_model, newdata = testing_data, type = "class")
mean(tree_pred != testing_High)
# Output below was recorded with a hard-coded best = 9; it is unchanged
# whenever cross-validation also selects size 9.
## [1] 0.23