Árboles de Decisión

Librerías

Carga de librerías requeridas

library(ISLR)
## Warning: package 'ISLR' was built under R version 4.3.3
library(tree)
## Warning: package 'tree' was built under R version 4.3.3
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(ggplot2)
## Warning: package 'ggplot2' was built under R version 4.3.3
library(caret)
## Loading required package: lattice
library(MASS)
## 
## Attaching package: 'MASS'
## The following object is masked from 'package:dplyr':
## 
##     select

Base “Hitters”

# Load the Hitters data set from the ISLR package and preview its columns.
data("Hitters", package = "ISLR")
dplyr::glimpse(Hitters)
## Rows: 322
## Columns: 20
## $ AtBat     <int> 293, 315, 479, 496, 321, 594, 185, 298, 323, 401, 574, 202, …
## $ Hits      <int> 66, 81, 130, 141, 87, 169, 37, 73, 81, 92, 159, 53, 113, 60,…
## $ HmRun     <int> 1, 7, 18, 20, 10, 4, 1, 0, 6, 17, 21, 4, 13, 0, 7, 3, 20, 2,…
## $ Runs      <int> 30, 24, 66, 65, 39, 74, 23, 24, 26, 49, 107, 31, 48, 30, 29,…
## $ RBI       <int> 29, 38, 72, 78, 42, 51, 8, 24, 32, 66, 75, 26, 61, 11, 27, 1…
## $ Walks     <int> 14, 39, 76, 37, 30, 35, 21, 7, 8, 65, 59, 27, 47, 22, 30, 11…
## $ Years     <int> 1, 14, 3, 11, 2, 11, 2, 3, 2, 13, 10, 9, 4, 6, 13, 3, 15, 5,…
## $ CAtBat    <int> 293, 3449, 1624, 5628, 396, 4408, 214, 509, 341, 5206, 4631,…
## $ CHits     <int> 66, 835, 457, 1575, 101, 1133, 42, 108, 86, 1332, 1300, 467,…
## $ CHmRun    <int> 1, 69, 63, 225, 12, 19, 1, 0, 6, 253, 90, 15, 41, 4, 36, 3, …
## $ CRuns     <int> 30, 321, 224, 828, 48, 501, 30, 41, 32, 784, 702, 192, 205, …
## $ CRBI      <int> 29, 414, 266, 838, 46, 336, 9, 37, 34, 890, 504, 186, 204, 1…
## $ CWalks    <int> 14, 375, 263, 354, 33, 194, 24, 12, 8, 866, 488, 161, 203, 2…
## $ League    <fct> A, N, A, N, N, A, N, A, N, A, A, N, N, A, N, A, N, A, A, N, …
## $ Division  <fct> E, W, W, E, E, W, E, W, W, E, E, W, E, E, E, W, W, W, W, W, …
## $ PutOuts   <int> 446, 632, 880, 200, 805, 282, 76, 121, 143, 0, 238, 304, 211…
## $ Assists   <int> 33, 43, 82, 11, 40, 421, 127, 283, 290, 0, 445, 45, 11, 151,…
## $ Errors    <int> 20, 10, 14, 3, 4, 25, 7, 9, 19, 0, 22, 11, 7, 6, 8, 0, 10, 1…
## $ Salary    <dbl> NA, 475.000, 480.000, 500.000, 91.500, 750.000, 70.000, 100.…
## $ NewLeague <fct> A, N, A, N, N, A, A, A, N, A, A, N, N, A, N, A, N, A, A, N, …
# NOTE(review): attach() puts a copy of Hitters on the search path; that copy
# is not updated by the later edits to Hitters (Salary imputation, dropped
# columns). The final ggplot call still refers to `Salary` after that column
# has been dropped from Hitters, so it appears to depend on this attached
# copy -- confirm before removing this line.
attach(Hitters)

Limpieza de datos

# Count missing values: the overall total first, then the per-column breakdown.
na_flags <- is.na(Hitters)
sum(na_flags)
## [1] 59
colSums(na_flags)
##     AtBat      Hits     HmRun      Runs       RBI     Walks     Years    CAtBat 
##         0         0         0         0         0         0         0         0 
##     CHits    CHmRun     CRuns      CRBI    CWalks    League  Division   PutOuts 
##         0         0         0         0         0         0         0         0 
##   Assists    Errors    Salary NewLeague 
##         0         0        59         0
# Per-column summary statistics (the recorded output lists NA's only for Salary).
summary(Hitters)
##      AtBat            Hits         HmRun            Runs       
##  Min.   : 16.0   Min.   :  1   Min.   : 0.00   Min.   :  0.00  
##  1st Qu.:255.2   1st Qu.: 64   1st Qu.: 4.00   1st Qu.: 30.25  
##  Median :379.5   Median : 96   Median : 8.00   Median : 48.00  
##  Mean   :380.9   Mean   :101   Mean   :10.77   Mean   : 50.91  
##  3rd Qu.:512.0   3rd Qu.:137   3rd Qu.:16.00   3rd Qu.: 69.00  
##  Max.   :687.0   Max.   :238   Max.   :40.00   Max.   :130.00  
##                                                                
##       RBI             Walks            Years            CAtBat       
##  Min.   :  0.00   Min.   :  0.00   Min.   : 1.000   Min.   :   19.0  
##  1st Qu.: 28.00   1st Qu.: 22.00   1st Qu.: 4.000   1st Qu.:  816.8  
##  Median : 44.00   Median : 35.00   Median : 6.000   Median : 1928.0  
##  Mean   : 48.03   Mean   : 38.74   Mean   : 7.444   Mean   : 2648.7  
##  3rd Qu.: 64.75   3rd Qu.: 53.00   3rd Qu.:11.000   3rd Qu.: 3924.2  
##  Max.   :121.00   Max.   :105.00   Max.   :24.000   Max.   :14053.0  
##                                                                      
##      CHits            CHmRun           CRuns             CRBI        
##  Min.   :   4.0   Min.   :  0.00   Min.   :   1.0   Min.   :   0.00  
##  1st Qu.: 209.0   1st Qu.: 14.00   1st Qu.: 100.2   1st Qu.:  88.75  
##  Median : 508.0   Median : 37.50   Median : 247.0   Median : 220.50  
##  Mean   : 717.6   Mean   : 69.49   Mean   : 358.8   Mean   : 330.12  
##  3rd Qu.:1059.2   3rd Qu.: 90.00   3rd Qu.: 526.2   3rd Qu.: 426.25  
##  Max.   :4256.0   Max.   :548.00   Max.   :2165.0   Max.   :1659.00  
##                                                                      
##      CWalks        League  Division    PutOuts          Assists     
##  Min.   :   0.00   A:175   E:157    Min.   :   0.0   Min.   :  0.0  
##  1st Qu.:  67.25   N:147   W:165    1st Qu.: 109.2   1st Qu.:  7.0  
##  Median : 170.50                    Median : 212.0   Median : 39.5  
##  Mean   : 260.24                    Mean   : 288.9   Mean   :106.9  
##  3rd Qu.: 339.25                    3rd Qu.: 325.0   3rd Qu.:166.0  
##  Max.   :1566.00                    Max.   :1378.0   Max.   :492.0  
##                                                                     
##      Errors          Salary       NewLeague
##  Min.   : 0.00   Min.   :  67.5   A:176    
##  1st Qu.: 3.00   1st Qu.: 190.0   N:146    
##  Median : 6.00   Median : 425.0            
##  Mean   : 8.04   Mean   : 535.9            
##  3rd Qu.:11.00   3rd Qu.: 750.0            
##  Max.   :32.00   Max.   :2460.0            
##                  NA's   :59
# Impute the 59 missing "Salary" entries with the median of the observed values.
salary_median <- median(Hitters$Salary, na.rm = TRUE)
salary_missing <- is.na(Hitters$Salary)
Hitters$Salary[salary_missing] <- salary_median

Análisis Exploratorio de los Datos

# Histogram of player salaries.
hist(
  Hitters$Salary,
  xlab = "Salario",
  main = "Distribución del Salario"
)

El histograma muestra una distribución sesgada hacia la derecha (asimetría positiva). Esto significa que hay una mayor concentración de salarios bajos y una cola larga hacia los salarios más altos. La distribución de los salarios va aproximadamente de 0 a 2500 (en miles de dólares), aunque se puede ver que hay picos en los salarios menores a 500.

# Side-by-side box plots of salary, one box per league.
boxplot(
  Salary ~ League,
  data = Hitters,
  xlab = "Liga",
  ylab = "Salario",
  main = "Salario por Liga"
)

La mediana del salario parece ser similar entre las ligas, lo que sugiere que el nivel salarial típico es parecido en ambas. La dispersión de los salarios (representada por el tamaño de las cajas y la longitud de los bigotes) también puede variar entre las ligas, lo que indica diferencias en la variabilidad de los salarios. Se puede identificar la presencia de valores atípicos (salarios excepcionalmente altos) en la liga A.

# Mean salary for each value of Years, ordered by Years.
salary_grouped <- group_by(Hitters, Years)
salary_by_experience <- arrange(
  summarise(salary_grouped, AverageSalary = mean(Salary, na.rm = TRUE)),
  Years
)

# Bar chart of those means, one bar per number of years.
barplot(
  height = salary_by_experience$AverageSalary,
  names.arg = salary_by_experience$Years,
  main = "Salario Promedio por Años de Experiencia",
  ylab = "Salario Promedio",
  col = "green"
)

Se puede observar que, conforme aumenta la cantidad de años de experiencia, el salario también aumenta; sin embargo, hay que mencionar que después de los 13 años el salario va disminuyendo. Aunque la carrera de los jugadores de béisbol es larga, se puede ver que el salario llega a su máximo alrededor de los 13 años jugando.

Árboles de Decisión

# Model log(Salary) rather than the raw salary.
Hitters$log_salary <- log(Hitters$Salary)

# Drop Hits, Years and the untransformed Salary from the modelling data.
dropped_cols <- c("Hits", "Years", "Salary")
Hitters <- Hitters[, !(names(Hitters) %in% dropped_cols)]


# Fix the RNG seed so the 70% training sample is reproducible.
set.seed(20240424)
n_obs <- nrow(Hitters)
train <- sample(seq_len(n_obs), floor(n_obs * 0.7))

# Regression tree for log_salary on all other columns, fitted on the training rows.
tree.hitters <- tree(log_salary ~ ., data = Hitters, subset = train)
summary(tree.hitters)
## 
## Regression tree:
## tree(formula = log_salary ~ ., data = Hitters, subset = train)
## Variables actually used in tree construction:
## [1] "CAtBat" "CHits"  "CRBI"   "AtBat"  "CWalks"
## Number of terminal nodes:  11 
## Residual mean deviance:  0.1504 = 32.18 / 214 
## Distribution of residuals:
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
## -1.455000 -0.226500 -0.008238  0.000000  0.243500  1.313000

Se utilizan 5 variables: “CAtBat”, “CHits”, “CRBI”, “AtBat” y “CWalks”.

# Draw the fitted tree and add its text labels.
plot(tree.hitters)
text(tree.hitters, col = "darkgreen", cex = 0.7)

Validación Cruzada

# Cross-validate the tree, then plot the CV deviance against tree size.
cv.hitters <- cv.tree(tree.hitters)
plot(cv.hitters$size, cv.hitters$dev, type = "b")

# Prune the tree with best = 6 and draw the pruned tree with labels.
prune.hitters <- prune.tree(tree.hitters, best = 6)
plot(prune.hitters)
text(prune.hitters, col = "darkgreen", cex = 0.9)

# Predict log_salary for the rows left out of the training sample.
yhat <- predict(prune.hitters, newdata = Hitters[-train, ])
hitters.test <- Hitters$log_salary[-train]

# Predicted vs. observed values with the y = x reference line, to see whether
# the model is good or bad.
plot(yhat, hitters.test)
abline(a = 0, b = 1)

# Mean squared error on the held-out rows.
mean((yhat - hitters.test)^2)
## [1] 0.3555823

Empaquetado y Bosques Aleatorios (Bagging and Random Forests)

library(randomForest)
## Warning: package 'randomForest' was built under R version 4.3.3
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
## The following object is masked from 'package:dplyr':
## 
##     combine
# Bagged regression trees for log_salary, fitted on the training rows.
# Bagging is a random forest that considers every predictor at each split, so
# mtry is set to the number of predictors (all columns except log_salary)
# instead of a hard-coded 13, which is fewer than the predictors in Hitters.
# NOTE(review): the output recorded below was produced with mtry = 13; re-run
# to refresh it.
set.seed(20240424)
n_predictors <- ncol(Hitters) - 1
bag.hitters <- randomForest(log_salary ~ ., data = Hitters, subset = train,
                            mtry = n_predictors, importance = TRUE)
bag.hitters
## 
## Call:
##  randomForest(formula = log_salary ~ ., data = Hitters, mtry = 13,      importance = TRUE, subset = train) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 13
## 
##           Mean of squared residuals: 0.2445683
##                     % Var explained: 63.32

En los siguientes códigos se evalúa el rendimiento del modelo Bagging:

# Predictions of the bagged model on the test rows.
yhat.bag <- predict(bag.hitters, newdata = Hitters[-train, ])

# Predictions against observed values, plus the perfect-prediction line.
plot(yhat.bag, hitters.test)
abline(a = 0, b = 1)

# Test MSE.
mean((yhat.bag - hitters.test)^2)
## [1] 0.2830136
# Refit the bagged model with only 25 trees and report its test MSE.
# mtry is the number of predictors (all columns except log_salary) so that
# every predictor is a candidate at each split, as bagging requires; a
# hard-coded 13 is fewer than the predictors in Hitters.
# NOTE(review): the MSE recorded below was produced with mtry = 13; re-run to
# refresh it.
set.seed(20240425)
bag.hitters <- randomForest(log_salary ~ ., data = Hitters, subset = train,
                            mtry = ncol(Hitters) - 1, ntree = 25)
yhat.bag <- predict(bag.hitters, newdata = Hitters[-train, ])
mean((yhat.bag - hitters.test)^2)
## [1] 0.3002457
# Random forest: 6 candidate predictors are tried at each split.
set.seed(20240424)
rf.hitters <- randomForest(
  log_salary ~ .,
  data = Hitters,
  subset = train,
  mtry = 6,
  importance = TRUE
)

# Test-set predictions and their mean squared error.
yhat.rf <- predict(rf.hitters, newdata = Hitters[-train, ])
mean((yhat.rf - hitters.test)^2)
## [1] 0.2777814
# Variable importance of the current bag.hitters model.
# NOTE(review): at this point bag.hitters is the 25-tree refit, which was
# fitted without importance = TRUE, so the recorded output shows only
# IncNodePurity. rf.hitters was fitted with importance = TRUE just above and
# may be the model intended here -- confirm.
importance(bag.hitters)
##           IncNodePurity
## AtBat         8.7881940
## HmRun         2.0034403
## Runs          9.9732957
## RBI           3.6326145
## Walks         7.4063329
## CAtBat       49.3498405
## CHits        34.9412363
## CHmRun        3.3510743
## CRuns         5.8835190
## CRBI          3.2431134
## CWalks       12.3597555
## League        0.6267698
## Division      0.1512166
## PutOuts       1.6700305
## Assists       1.6668453
## Errors        3.6017725
## NewLeague     0.4475874
# Plot variable importance for the current bag.hitters model.
# NOTE(review): bag.hitters was last refitted with ntree = 25 and without
# importance = TRUE; rf.hitters may be the model intended here -- confirm.
varImpPlot(bag.hitters)

# Two-predictor regression tree (CAtBat and Runs); no subset is given, so it
# is fitted on every row of Hitters.
treefit <- tree(formula = log_salary ~ CAtBat + Runs, data = Hitters)
print(treefit)
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 322 207.900 5.950  
##    2) CAtBat < 1452 133  57.600 5.309  
##      4) CAtBat < 699 70  36.800 5.033  
##        8) CAtBat < 173.5 8   2.481 6.350 *
##        9) CAtBat > 173.5 62  18.650 4.863 *
##      5) CAtBat > 699 63   9.497 5.616 *
##    3) CAtBat > 1452 189  57.250 6.401  
##      6) Runs < 55 92  18.580 6.156 *
##      7) Runs > 55 97  27.920 6.633  
##       14) CAtBat < 2198.5 23   2.800 6.285 *
##       15) CAtBat > 2198.5 74  21.460 6.742 *
# Draw the two-predictor tree and add its text labels.
plot(treefit)
text(treefit, col = "darkgreen", cex = 0.7)

# Scatter plot of CAtBat vs. Runs coloured by salary, overlaid with the
# partition produced by `treefit` (split values taken from its printed output).
# - Salary is recovered as exp(log_salary) because the Salary column has been
#   dropped from Hitters by this point.
# - The partition lines are constants, so each is drawn once with annotate()
#   rather than with geom_segment(aes(...)), which repeats the segment for
#   every data row.
# - The CAtBat = 2198.5 split applies only where Runs > 55, so that segment
#   runs from Runs = 55 up to the top of the plot.
runs_top <- 130      # upper end of the vertical lines (maximum Runs)
catbat_top <- 14053  # right end of the Runs = 55 line (maximum CAtBat)
ggplot(data = Hitters, aes(x = CAtBat, y = Runs)) +
  geom_point(aes(color = exp(log_salary))) +
  scale_color_continuous(name = "Salary") +
  annotate("segment", x = 173.5, xend = 173.5, y = 0, yend = runs_top) +
  annotate("segment", x = 699, xend = 699, y = 0, yend = runs_top) +
  annotate("segment", x = 1452, xend = 1452, y = 0, yend = runs_top) +
  annotate("segment", x = 2198.5, xend = 2198.5, y = 55, yend = runs_top) +
  annotate("segment", x = 1452, xend = catbat_top, y = 55, yend = 55)