All supplementary course materials and slides are available at

Predictive Analysis

What do we want to predict?

  • The behavior of an individual or a specific event (an action, event, or occurrence), e.g. which ad is each customer most likely to click?

What are the predictions used for?

  • Decisions based on the predictions: the organization responds to each prediction, or learns from the prediction what action to take.

Simple Linear Regression
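The least-squares estimates computed below follow the closed-form solution: slope B1 = (sum(x*y) - n*mean(x)*mean(y)) / (sum(x^2) - n*mean(x)^2), and intercept B0 = mean(y) - B1*mean(x).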

X <- c(17, 21, 35, 39, 50, 65)
Y <- c(132, 150, 160, 162, 149, 170)

X_avg <- mean(X)
Y_avg <- mean(Y)

# Compute the regression slope
B1 <- (sum(X * Y) - length(X) * X_avg * Y_avg) / (sum(X^2) - length(X) * X_avg^2)

# Compute the intercept
B0 <- Y_avg - B1 * X_avg

# Draw the fitted line over the observations
x <- 1:70
y <- B1 * x + B0
plot(X, Y)
lines(x, y, col = 'red')

X <- c(17,   21,  35,  39,  50,  65)
Y <- c(132, 150, 160, 162, 149, 170)

data <- data.frame(X= X, Y=Y)
fit <- lm(Y ~ X)
summary(fit)
## 
## Call:
## lm(formula = Y ~ X)
## 
## Residuals:
##       1       2       3       4       5       6 
## -10.313   5.475   7.733   7.522 -11.561   1.145 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## (Intercept) 132.9130    10.1079  13.149 0.000193 ***
## X             0.5530     0.2451   2.256 0.087095 .  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 9.845 on 4 degrees of freedom
## Multiple R-squared:  0.5599, Adjusted R-squared:  0.4498 
## F-statistic: 5.088 on 1 and 4 DF,  p-value: 0.08709
plot(X, Y)
abline(fit, col="red")

Polynomial Regression

library(car)
data(Quartet)

lmfit=lm(y2~x, data = Quartet)
summary(lmfit)
## 
## Call:
## lm(formula = y2 ~ x, data = Quartet)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -1.9009 -0.7609  0.1291  0.9491  1.2691 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)   
## (Intercept)    3.001      1.125   2.667  0.02576 * 
## x              0.500      0.118   4.239  0.00218 **
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 1.237 on 9 degrees of freedom
## Multiple R-squared:  0.6662, Adjusted R-squared:  0.6292 
## F-statistic: 17.97 on 1 and 9 DF,  p-value: 0.002179
cf <- summary(lmfit)$coefficients
print( paste( 'FORMULA y = ' ,cf[2,1] , ' * ' ,'x' , 
                     ' + ', cf[1,1]))
## [1] "FORMULA y =  0.5  *  x  +  3.00090909090909"
plot(y2 ~x, data= Quartet)
abline(lmfit, col="red")

# Test inputs
xp <- c(6, 12)
# Generate predictions
yp <- predict(lmfit, data.frame(x = xp))
points(x = xp, y = yp, col = 'blue')

print(data.frame(test= xp, predict = yp))
##   test  predict
## 1    6 6.000909
## 2   12 9.000909
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
## Second-degree (quadratic) polynomial regression
Quartet <- Quartet %>%
  arrange(x)
lmfit2 <- lm(y2 ~ poly(x, 2), data = Quartet)
summary(lmfit2)
## 
## Call:
## lm(formula = y2 ~ poly(x, 2), data = Quartet)
## 
## Residuals:
##        Min         1Q     Median         3Q        Max 
## -0.0013287 -0.0011888 -0.0006294  0.0008741  0.0023776 
## 
## Coefficients:
##               Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  7.5009091  0.0005043   14875   <2e-16 ***
## poly(x, 2)1  5.2440442  0.0016725    3135   <2e-16 ***
## poly(x, 2)2 -3.7116396  0.0016725   -2219   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.001672 on 8 degrees of freedom
## Multiple R-squared:      1,  Adjusted R-squared:      1 
## F-statistic: 7.378e+06 on 2 and 8 DF,  p-value: < 2.2e-16
cf2 <- summary(lmfit2)$coefficients
# Note: poly(x, 2) fits orthogonal polynomials, so these coefficients are not
# the raw coefficients of x and x^2 (use poly(x, 2, raw = TRUE) for those)
print( paste( 'FORMULA y = ', cf2[3,1], ' * ', 'x^2',
                     ' + ', cf2[2,1], ' * ', 'x',
                     ' + ', cf2[1,1]))
## [1] "FORMULA y =  -3.71163960150612  *  x^2  +  5.24404424085076  *  x  +  7.50090909090909"
lines(Quartet$x, lmfit2$fit, col="red")

Robust regression with rlm (Robust Fitting of Linear Models): rather than dropping outliers outright, rlm down-weights observations with large residuals, so the fitted line is far less affected by them.

fit3 <- lm(y3~x, data = Quartet)
summary(fit3)
## 
## Call:
## lm(formula = y3 ~ x, data = Quartet)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -1.1586 -0.6146 -0.2303  0.1540  3.2411 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)   
## (Intercept)   3.0025     1.1245   2.670  0.02562 * 
## x             0.4997     0.1179   4.239  0.00218 **
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 1.236 on 9 degrees of freedom
## Multiple R-squared:  0.6663, Adjusted R-squared:  0.6292 
## F-statistic: 17.97 on 1 and 9 DF,  p-value: 0.002176
plot(y3~x, data= Quartet)
lines(Quartet$x, fit3$fit, col="red")

library(MASS)
## 
## Attaching package: 'MASS'
## The following object is masked from 'package:dplyr':
## 
##     select
fit4 <- rlm(y3~x, data = Quartet)
summary(fit4)
## 
## Call: rlm(formula = y3 ~ x, data = Quartet)
## Residuals:
##        Min         1Q     Median         3Q        Max 
## -0.0049962 -0.0028591 -0.0007219  0.0028667  4.2421008 
## 
## Coefficients:
##             Value    Std. Error t value 
## (Intercept)   4.0035   0.0040   990.3355
## x             0.3457   0.0004   815.8284
## 
## Residual standard error: 0.005248 on 9 degrees of freedom
lines(Quartet$x, fit4$fit, col="blue")

Building a linear regression from scraped 591 rental-site data

setwd("D:/CXLWorkspace/RS/RML")
if(!file.exists('houseprice.csv')) {
  download.file('https://raw.githubusercontent.com/ywchiu/rtibame/master/History/Class1/591.csv', destfile = 'houseprice.csv')  
}

house <- read.csv("houseprice.csv", header = TRUE)

fithm <- lm(Price ~ Area, data= house)
summary(fithm)
## 
## Call:
## lm(formula = Price ~ Area, data = house)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -103542  -13882   -1784   13432  255710 
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept) -12811.23    2234.57  -5.733 1.51e-08 ***
## Area          1768.70      44.72  39.551  < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 27590 on 646 degrees of freedom
## Multiple R-squared:  0.7077, Adjusted R-squared:  0.7073 
## F-statistic:  1564 on 1 and 646 DF,  p-value: < 2.2e-16
plot(Price ~ Area, data= house, main="House Price", xlab="Area", ylab="Price")

abline(fithm, col= 'green')

Error (random term) assumptions required by ordinary least squares (OLS)
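As a quick way to eyeball these assumptions in R, a sketch using the fithm model fitted above: the built-in lm diagnostics show residuals vs. fitted values (linearity, constant variance), a normal Q-Q plot (normality), scale-location, and leverage.

par(mfrow = c(2, 2))
plot(fithm)
par(mfrow = c(1, 1))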

Validating the Regression Model

Significance Test for the Linear Relationship

If there were no relationship between the variables, β1 should be close to 0; in practice, however, the estimated β1 will almost never be exactly zero, so how do we verify that β1 is effectively 0?

Coefficients:
             Estimate Std. Error t value Pr(>|t|)   
 (Intercept)    3.001      1.125   2.667  0.02576 * 
 x              0.500      0.118   4.239  0.00218 **    
The smaller a predictor's p-value (< 0.05), the stronger the evidence of an association between the independent variable and the dependent variable.

The p-value is a probability: assuming the null hypothesis is true, it is the probability of obtaining a test statistic at least as extreme as the one computed from the sample.
The smaller the p-value, the stronger the evidence against the null hypothesis.
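In R, these per-coefficient p-values can be read directly off the summary table; for example, for the lmfit model above:

summary(lmfit)$coefficients[, "Pr(>|t|)"]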

R-Squared
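R-squared is the proportion of the variance in the response explained by the model: R^2 = 1 - sum((y_i - yhat_i)^2) / sum((y_i - ybar)^2). The helper below computes this, together with RMSE = sqrt(mean((predicted - actual)^2)), from vectors of actual and predicted values.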

Evaluating a Linear Model

evaluateModel <- function(actual, predicted) {
  # Root mean squared error
  rmse <- (mean((predicted - actual)^2))^0.5

  # Relative squared error: model MSE over the MSE of predicting the mean
  mu <- mean(actual)
  rse <- mean((predicted - actual)^2) / mean((mu - actual)^2)

  # R-squared = 1 - RSE
  rsquare <- 1 - rse
  print(paste('RMSE:', rmse))
  print(paste('R2:', rsquare))
  return(rsquare)
}

Ordinary Linear Model (lm)

lmfit <- lm(Quartet$y3~Quartet$x)
plot(Quartet$x, Quartet$y3)
abline(lmfit, col="red")

predicted <- predict(lmfit, newdata=Quartet[c("x")])
actual <- Quartet$y3



summary(lmfit)
## 
## Call:
## lm(formula = Quartet$y3 ~ Quartet$x)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -1.1586 -0.6146 -0.2303  0.1540  3.2411 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)   
## (Intercept)   3.0025     1.1245   2.670  0.02562 * 
## Quartet$x     0.4997     0.1179   4.239  0.00218 **
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 1.236 on 9 degrees of freedom
## Multiple R-squared:  0.6663, Adjusted R-squared:  0.6292 
## F-statistic: 17.97 on 1 and 9 DF,  p-value: 0.002176
r2 <-  evaluateModel(actual, predicted)
## [1] "RMSE: 1.11828569362305"
## [1] "R2: 0.666324041066559"

Robust Linear Model (rlm)

rlmfit <- rlm(Quartet$y3~Quartet$x)
plot(Quartet$x, Quartet$y3)
abline(rlmfit, col="red")

predicted <- predict(rlmfit, newdata=Quartet[c("x")])
actual <- Quartet$y3

summary(rlmfit)
## 
## Call: rlm(formula = Quartet$y3 ~ Quartet$x)
## Residuals:
##        Min         1Q     Median         3Q        Max 
## -0.0049962 -0.0028591 -0.0007219  0.0028667  4.2421008 
## 
## Coefficients:
##             Value    Std. Error t value 
## (Intercept)   4.0035   0.0040   990.3355
## Quartet$x     0.3457   0.0004   815.8284
## 
## Residual standard error: 0.005248 on 9 degrees of freedom
r2 <- evaluateModel(actual, predicted)
## [1] "RMSE: 1.27904477056295"
## [1] "R2: 0.563493342191402"

Model Validation Summary

setwd("D:/CXLWorkspace/RS/RML")
if(!file.exists('house-prices.csv')) {
    download.file('https://raw.githubusercontent.com/ywchiu/cathayml/master/dataset/house-prices.csv', 
destfile = 'house-prices.csv')
}
house_prices <- read.csv('house-prices.csv', header= TRUE)
head(house_prices)
##   Home  Price SqFt Bedrooms Bathrooms Offers Brick Neighborhood
## 1    1 114300 1790        2         2      2    No         East
## 2    2 114200 2030        4         2      3    No         East
## 3    3 114800 1740        3         2      1    No         East
## 4    4  94700 1980        3         2      3    No         East
## 5    5 119800 2130        3         3      3    No         East
## 6    6 114600 1780        3         2      2    No        North
lm.1 <- lm(Price~SqFt, data= house_prices)

predicted <- predict(lm.1, newdata=house_prices[c("SqFt")])
actual <- house_prices$Price

# F-statistic = Model Mean Square / Error Mean Square
summary(lm.1)
## 
## Call:
## lm(formula = Price ~ SqFt, data = house_prices)
## 
## Residuals:
##    Min     1Q Median     3Q    Max 
## -46593 -16644  -1610  15124  54829 
## 
## Coefficients:
##               Estimate Std. Error t value Pr(>|t|)    
## (Intercept) -10091.130  18966.104  -0.532    0.596    
## SqFt            70.226      9.426   7.450  1.3e-11 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 22480 on 126 degrees of freedom
## Multiple R-squared:  0.3058, Adjusted R-squared:  0.3003 
## F-statistic:  55.5 on 1 and 126 DF,  p-value: 1.302e-11
evaluateModel(actual, predicted)
## [1] "RMSE: 22299.2522370143"
## [1] "R2: 0.305789360581692"
## [1] 0.3057894
# Inspect the residual histogram
hist(lm.1$residuals)

# Converting every level of a categorical variable into its own dummy variable
# leads to the dummy-variable trap (perfect collinearity), so one level is
# always left out as the baseline
house_prices$brick_d <- ifelse(house_prices$Brick == "Yes", 1, 0)
house_prices$east <- ifelse(house_prices$Neighborhood == "East", 1, 0)
house_prices$north <- ifelse(house_prices$Neighborhood == "North", 1, 0)

Building training and validation sets; training and evaluating the linear regression model

set.seed(110)
sub<-sample(nrow(house_prices), floor(nrow(house_prices)*0.6))

training_data<-house_prices[sub,]
validation_data<-house_prices[-sub,]


lm.fit1 <-lm(Price ~SqFt+Bathrooms+Bedrooms+Offers+north+east+brick_d, data=training_data)

summary(lm.fit1)
## 
## Call:
## lm(formula = Price ~ SqFt + Bathrooms + Bedrooms + Offers + north + 
##     east + brick_d, data = training_data)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -28809.1  -5439.8   -251.1   5716.9  26720.7 
## 
## Coefficients:
##               Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  33263.403  13018.025   2.555   0.0129 *  
## SqFt            48.918      7.103   6.887 2.26e-09 ***
## Bathrooms     4886.975   2746.714   1.779   0.0797 .  
## Bedrooms      4352.011   1971.204   2.208   0.0306 *  
## Offers       -5655.299   1314.227  -4.303 5.52e-05 ***
## north       -23296.029   3891.067  -5.987 8.96e-08 ***
## east        -22978.967   3063.771  -7.500 1.77e-10 ***
## brick_d      18500.732   2379.993   7.773 5.65e-11 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 9244 on 68 degrees of freedom
## Multiple R-squared:  0.8864, Adjusted R-squared:  0.8747 
## F-statistic: 75.77 on 7 and 68 DF,  p-value: < 2.2e-16

Dropping Insignificant Variables
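step() scores candidate models with the Akaike information criterion, AIC = 2k - 2·ln(L), where k is the number of estimated parameters and L the maximized likelihood; at each step it removes the variable whose deletion lowers AIC the most, stopping when no deletion helps.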

step(lm.fit1)
## Start:  AIC=1395.57
## Price ~ SqFt + Bathrooms + Bedrooms + Offers + north + east + 
##     brick_d
## 
##             Df  Sum of Sq        RSS    AIC
## <none>                    5.8106e+09 1395.6
## - Bathrooms  1  270499345 6.0811e+09 1397.0
## - Bedrooms   1  416513849 6.2271e+09 1398.8
## - Offers     1 1582279549 7.3929e+09 1411.9
## - north      1 3062944430 8.8736e+09 1425.8
## - SqFt       1 4053393477 9.8640e+09 1433.8
## - east       1 4806858174 1.0617e+10 1439.4
## - brick_d    1 5163439960 1.0974e+10 1441.9
## 
## Call:
## lm(formula = Price ~ SqFt + Bathrooms + Bedrooms + Offers + north + 
##     east + brick_d, data = training_data)
## 
## Coefficients:
## (Intercept)         SqFt    Bathrooms     Bedrooms       Offers  
##    33263.40        48.92      4886.98      4352.01     -5655.30  
##       north         east      brick_d  
##   -23296.03    -22978.97     18500.73
library(MASS)
lm.fit1.step <-stepAIC(lm.fit1)
## Start:  AIC=1395.57
## Price ~ SqFt + Bathrooms + Bedrooms + Offers + north + east + 
##     brick_d
## 
##             Df  Sum of Sq        RSS    AIC
## <none>                    5.8106e+09 1395.6
## - Bathrooms  1  270499345 6.0811e+09 1397.0
## - Bedrooms   1  416513849 6.2271e+09 1398.8
## - Offers     1 1582279549 7.3929e+09 1411.9
## - north      1 3062944430 8.8736e+09 1425.8
## - SqFt       1 4053393477 9.8640e+09 1433.8
## - east       1 4806858174 1.0617e+10 1439.4
## - brick_d    1 5163439960 1.0974e+10 1441.9
summary(lm.fit1.step)
## 
## Call:
## lm(formula = Price ~ SqFt + Bathrooms + Bedrooms + Offers + north + 
##     east + brick_d, data = training_data)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -28809.1  -5439.8   -251.1   5716.9  26720.7 
## 
## Coefficients:
##               Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  33263.403  13018.025   2.555   0.0129 *  
## SqFt            48.918      7.103   6.887 2.26e-09 ***
## Bathrooms     4886.975   2746.714   1.779   0.0797 .  
## Bedrooms      4352.011   1971.204   2.208   0.0306 *  
## Offers       -5655.299   1314.227  -4.303 5.52e-05 ***
## north       -23296.029   3891.067  -5.987 8.96e-08 ***
## east        -22978.967   3063.771  -7.500 1.77e-10 ***
## brick_d      18500.732   2379.993   7.773 5.65e-11 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 9244 on 68 degrees of freedom
## Multiple R-squared:  0.8864, Adjusted R-squared:  0.8747 
## F-statistic: 75.77 on 7 and 68 DF,  p-value: < 2.2e-16

Selecting Model Predictors with step() via AIC

house <- read.csv("houseprice.csv", header = TRUE)

fithm <- lm(Price ~ ., data= house)
summary(fithm)
## 
## Call:
## lm(formula = Price ~ ., data = house)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -124298  -12176    -275   10631  233276 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## (Intercept) -2834.46    4939.84  -0.574 0.566306    
## Area         1908.46      70.35  27.130  < 2e-16 ***
## Floor        1136.90     326.39   3.483 0.000529 ***
## TotalFloor   -268.31     265.24  -1.012 0.312129    
## Bedroom     -9081.26    1390.83  -6.529 1.34e-10 ***
## Living.Room   458.81    2786.81   0.165 0.869283    
## Bathroom     2757.43    2480.95   1.111 0.266796    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 26440 on 641 degrees of freedom
## Multiple R-squared:  0.7337, Adjusted R-squared:  0.7312 
## F-statistic: 294.3 on 6 and 641 DF,  p-value: < 2.2e-16
stepfithm <- step(fithm)
## Start:  AIC=13203.49
## Price ~ Area + Floor + TotalFloor + Bedroom + Living.Room + Bathroom
## 
##               Df  Sum of Sq        RSS   AIC
## - Living.Room  1 1.8943e+07 4.4801e+11 13202
## - TotalFloor   1 7.1515e+08 4.4871e+11 13202
## - Bathroom     1 8.6335e+08 4.4886e+11 13203
## <none>                      4.4799e+11 13204
## - Floor        1 8.4798e+09 4.5647e+11 13214
## - Bedroom      1 2.9796e+10 4.7779e+11 13243
## - Area         1 5.1440e+11 9.6240e+11 13697
## 
## Step:  AIC=13201.52
## Price ~ Area + Floor + TotalFloor + Bedroom + Bathroom
## 
##              Df  Sum of Sq        RSS   AIC
## - TotalFloor  1 7.2245e+08 4.4873e+11 13201
## - Bathroom    1 9.9776e+08 4.4901e+11 13201
## <none>                     4.4801e+11 13202
## - Floor       1 8.4621e+09 4.5647e+11 13212
## - Bedroom     1 3.1079e+10 4.7909e+11 13243
## - Area        1 5.1871e+11 9.6673e+11 13698
## 
## Step:  AIC=13200.56
## Price ~ Area + Floor + Bedroom + Bathroom
## 
##            Df  Sum of Sq        RSS   AIC
## - Bathroom  1 1.1752e+09 4.4991e+11 13200
## <none>                   4.4873e+11 13201
## - Floor     1 8.5597e+09 4.5729e+11 13211
## - Bedroom   1 3.0550e+10 4.7928e+11 13241
## - Area      1 5.5375e+11 1.0025e+12 13719
## 
## Step:  AIC=13200.26
## Price ~ Area + Floor + Bedroom
## 
##           Df  Sum of Sq        RSS   AIC
## <none>                  4.4991e+11 13200
## - Floor    1 8.7129e+09 4.5862e+11 13211
## - Bedroom  1 3.0410e+10 4.8032e+11 13241
## - Area     1 9.4161e+11 1.3915e+12 13930
summary(stepfithm)
## 
## Call:
## lm(formula = Price ~ Area + Floor + Bedroom, data = house)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -125676  -12107    -201   10241  234913 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## (Intercept) -2575.44    3451.37  -0.746 0.455814    
## Area         1942.38      52.91  36.713  < 2e-16 ***
## Floor         959.13     271.59   3.532 0.000443 ***
## Bedroom     -8284.05    1255.61  -6.598 8.72e-11 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 26430 on 644 degrees of freedom
## Multiple R-squared:  0.7325, Adjusted R-squared:  0.7313 
## F-statistic: 587.9 on 3 and 644 DF,  p-value: < 2.2e-16

Checking for Collinearity
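For predictor j, VIF_j = 1 / (1 - R_j^2), where R_j^2 is the R-squared from regressing predictor j on all the other predictors; a common rule of thumb flags VIF values above 5-10 as a sign of problematic collinearity.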

vif(lm.fit1)
##      SqFt Bathrooms  Bedrooms    Offers     north      east   brick_d 
##  2.077096  1.760002  1.816514  1.798679  3.084402  1.879012  1.172261
vif(fithm)
##        Area       Floor  TotalFloor     Bedroom Living.Room    Bathroom 
##    2.694499    1.485638    1.633546    1.828493    1.613878    2.917688
vif(stepfithm)
##     Area    Floor  Bedroom 
## 1.524799 1.029081 1.490825

Inspecting the Residuals

## Training data
training_data$predict.price <- predict(lm.fit1)
training_data$error <- residuals(lm.fit1)

## Validation data
validation_data$predict.price <- predict(lm.fit1, newdata = validation_data)
validation_data$error <- validation_data$predict.price - validation_data$Price

par(mfrow=c(1,2))
hist(training_data$error, main = 'Train Data Error') 

hist(validation_data$error, main = 'Validate Data Error') 

Checking R-Squared

## Training data
a<-cor(training_data$Price,training_data$predict.price)
a*a
## [1] 0.8863606
## Validation data
b<-cor(validation_data$Price,validation_data$predict.price)
b*b
## [1] 0.840097

Decision Trees

Classifying the Iris Data with rpart

library(rpart)
data(iris)
names(iris)
## [1] "Sepal.Length" "Sepal.Width"  "Petal.Length" "Petal.Width" 
## [5] "Species"
iris.pch = ifelse(iris$Species == 'setosa' , 1 ,  ifelse(iris$Species == 'virginica' , 3 ,  5))

plot(iris[,-5], col=iris$Species , pch= iris.pch)

fit <-rpart(Species ~., data=iris)
summary(fit)
## Call:
## rpart(formula = Species ~ ., data = iris)
##   n= 150 
## 
##     CP nsplit rel error xerror       xstd
## 1 0.50      0      1.00   1.18 0.05017303
## 2 0.44      1      0.50   0.64 0.06057502
## 3 0.01      2      0.06   0.08 0.02751969
## 
## Variable importance
##  Petal.Width Petal.Length Sepal.Length  Sepal.Width 
##           34           31           21           14 
## 
## Node number 1: 150 observations,    complexity param=0.5
##   predicted class=setosa      expected loss=0.6666667  P(node) =1
##     class counts:    50    50    50
##    probabilities: 0.333 0.333 0.333 
##   left son=2 (50 obs) right son=3 (100 obs)
##   Primary splits:
##       Petal.Length < 2.45 to the left,  improve=50.00000, (0 missing)
##       Petal.Width  < 0.8  to the left,  improve=50.00000, (0 missing)
##       Sepal.Length < 5.45 to the left,  improve=34.16405, (0 missing)
##       Sepal.Width  < 3.35 to the right, improve=19.03851, (0 missing)
##   Surrogate splits:
##       Petal.Width  < 0.8  to the left,  agree=1.000, adj=1.00, (0 split)
##       Sepal.Length < 5.45 to the left,  agree=0.920, adj=0.76, (0 split)
##       Sepal.Width  < 3.35 to the right, agree=0.833, adj=0.50, (0 split)
## 
## Node number 2: 50 observations
##   predicted class=setosa      expected loss=0  P(node) =0.3333333
##     class counts:    50     0     0
##    probabilities: 1.000 0.000 0.000 
## 
## Node number 3: 100 observations,    complexity param=0.44
##   predicted class=versicolor  expected loss=0.5  P(node) =0.6666667
##     class counts:     0    50    50
##    probabilities: 0.000 0.500 0.500 
##   left son=6 (54 obs) right son=7 (46 obs)
##   Primary splits:
##       Petal.Width  < 1.75 to the left,  improve=38.969400, (0 missing)
##       Petal.Length < 4.75 to the left,  improve=37.353540, (0 missing)
##       Sepal.Length < 6.15 to the left,  improve=10.686870, (0 missing)
##       Sepal.Width  < 2.45 to the left,  improve= 3.555556, (0 missing)
##   Surrogate splits:
##       Petal.Length < 4.75 to the left,  agree=0.91, adj=0.804, (0 split)
##       Sepal.Length < 6.15 to the left,  agree=0.73, adj=0.413, (0 split)
##       Sepal.Width  < 2.95 to the left,  agree=0.67, adj=0.283, (0 split)
## 
## Node number 6: 54 observations
##   predicted class=versicolor  expected loss=0.09259259  P(node) =0.36
##     class counts:     0    49     5
##    probabilities: 0.000 0.907 0.093 
## 
## Node number 7: 46 observations
##   predicted class=virginica   expected loss=0.02173913  P(node) =0.3066667
##     class counts:     0     1    45
##    probabilities: 0.000 0.022 0.978
plot(fit, margin =0.1)
text(fit)

# Visualize the rpart split boundaries (Petal.Length = 2.45, Petal.Width = 1.75)
plot(Petal.Width ~ Petal.Length, data = iris, col = Species)
abline(v = 2.45, col = "blue")
abline(h = 1.75, col = "orange")

## Inspect the classification results
## type = "class" returns the class with the highest predicted probability
table(predict(fit, iris[,1:4], type="class"), iris[,5])
##             
##              setosa versicolor virginica
##   setosa         50          0         0
##   versicolor      0         49         5
##   virginica       0          1        45
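As a quick follow-up sketch, the overall accuracy can be read off this confusion matrix as the share of the diagonal:

tab <- table(predict(fit, iris[,1:4], type="class"), iris[,5])
sum(diag(tab)) / sum(tab)   # (50 + 49 + 45) / 150 = 0.96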

Classifying the Iris Data with party's ctree

library(party)
## Loading required package: grid
## Loading required package: mvtnorm
## Loading required package: modeltools
## Loading required package: stats4
## Loading required package: strucchange
## Loading required package: zoo
## 
## Attaching package: 'zoo'
## The following objects are masked from 'package:base':
## 
##     as.Date, as.Date.numeric
## Loading required package: sandwich
ctfit <-ctree(Species ~., data=iris)
summary(ctfit)
##     Length      Class       Mode 
##          1 BinaryTree         S4
# For the available plotting options, see party:::plot.BinaryTree
plot(ctfit, type="simple",           # simple nodes, no terminal plots
     inner_panel=node_inner(ctfit,
                            abbreviate = FALSE,  # full variable names
                            pval = TRUE,         # show p-values
                            id = FALSE),         # hide node ids
     terminal_panel=node_terminal(ctfit, 
                                  abbreviate = TRUE,
                                  digits = 1,          # few digits on numbers
                                  fill = c("white"),   # white boxes, not grey
                                  id = FALSE)          # hide node ids
)

# Visualize the ctree split boundaries on the petal scatter plot
plot(Petal.Width ~ Petal.Length, data = iris, col = Species)
abline(v = 1.95, col = "black")
abline(h = 1.7, col = "green")
lines(x = rep(4.8, 18), y = seq(0, 1.7, 0.1), col = 'red')

## Inspect the classification results
## ctree's predict() returns the most probable class by default
table(predict(ctfit, iris[,1:4]), iris[,5])
##             
##              setosa versicolor virginica
##   setosa         50          0         0
##   versicolor      0         49         5
##   virginica       0          1        45

Logistic Regression: Classifying the Iris Data
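Logistic regression models the class probability with the logistic function, P(virginica | x) = 1 / (1 + exp(-(β0 + β1·x1 + … + β4·x4))); predict(..., type='response') below returns exactly this probability, and 0.5 is used as the decision threshold.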

data(iris)
lr.iris <-iris[(iris$Species!="setosa"),]
lr.iris$Species <- factor(lr.iris$Species)

# Generalized Linear Model with a binomial family, i.e. logistic regression
lrfit <- glm(Species ~ ., data = lr.iris, family = binomial)
summary(lrfit)
## 
## Call:
## glm(formula = Species ~ ., family = binomial, data = lr.iris)
## 
## Deviance Residuals: 
##      Min        1Q    Median        3Q       Max  
## -2.01105  -0.00541  -0.00001   0.00677   1.78065  
## 
## Coefficients:
##              Estimate Std. Error z value Pr(>|z|)  
## (Intercept)   -42.638     25.707  -1.659   0.0972 .
## Sepal.Length   -2.465      2.394  -1.030   0.3032  
## Sepal.Width    -6.681      4.480  -1.491   0.1359  
## Petal.Length    9.429      4.737   1.991   0.0465 *
## Petal.Width    18.286      9.743   1.877   0.0605 .
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 138.629  on 99  degrees of freedom
## Residual deviance:  11.899  on 95  degrees of freedom
## AIC: 21.899
## 
## Number of Fisher Scoring iterations: 10
lr.predict <- predict(lrfit, lr.iris[,1:4], type='response') 
lr.iris[lr.predict >0.5,'Species']
##  [1] versicolor virginica  virginica  virginica  virginica  virginica 
##  [7] virginica  virginica  virginica  virginica  virginica  virginica 
## [13] virginica  virginica  virginica  virginica  virginica  virginica 
## [19] virginica  virginica  virginica  virginica  virginica  virginica 
## [25] virginica  virginica  virginica  virginica  virginica  virginica 
## [31] virginica  virginica  virginica  virginica  virginica  virginica 
## [37] virginica  virginica  virginica  virginica  virginica  virginica 
## [43] virginica  virginica  virginica  virginica  virginica  virginica 
## [49] virginica  virginica 
## Levels: versicolor virginica
lr.iris$predict <- ifelse(lr.predict >0.5, 'virginica', 'versicolor')


table(lr.iris$predict, lr.iris$Species)
##             
##              versicolor virginica
##   versicolor         49         1
##   virginica           1        49
lrCT.iris <-iris[(iris$Species!="setosa"),]
lrctfit <-ctree(Species ~., data=lrCT.iris)
summary(lrctfit)
##     Length      Class       Mode 
##          1 BinaryTree         S4
table(predict(lrctfit, lrCT.iris[,1:4]), lrCT.iris[,5])
##             
##              setosa versicolor virginica
##   setosa          0          0         0
##   versicolor      0         49         5
##   virginica       0          1        45

Support Vector Machines: Classifying the Iris Data
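For a linear SVM, the decision boundary is w·x + b = 0 and the margins are the parallel lines w·x + b = ±1. In the helper below, w is recovered as t(svm.model$coefs) %*% svm.model$SV and b as -svm.model$rho, which is what the three abline() calls draw.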

library(e1071)
data(iris)

irisSVMplotByCost <- function (iris.data, svmcost, xlab = "Sepal.Length", ylab = 'Sepal.Width', tlab = 'Species') {

  plottitle <- paste('Classify iris by SVM with ', xlab, ylab)

  # Linear-kernel SVM at the given cost
  svm.model = svm(as.formula(paste(tlab, '~.')), data = iris.data, kernel = 'linear', cost = svmcost, scale = FALSE)
  plot(svm.model, iris.data)
  plot(x = iris.data[[xlab]], y = iris.data[[ylab]], col = iris.data[[tlab]], xlab = xlab, ylab = ylab, main = plottitle, pch = 19)

  # Highlight the support vectors
  points(iris.data[svm.model$index, c(1,2)], col = "blue", cex = 2)
  # Recover the hyperplane w.x + b = 0 and draw it with its two margins
  w = t(svm.model$coefs) %*% svm.model$SV
  b = -svm.model$rho
  abline(a = -b/w[1,2], b = -w[1,1]/w[1,2], col = "red", lty = 5)
  abline(a = (-b-1)/w[1,2], b = -w[1,1]/w[1,2], col = "orange", lty = 3)
  abline(a = (-b+1)/w[1,2], b = -w[1,1]/w[1,2], col = "orange", lty = 3)
}


iris.sepal.subset = subset(iris, 
                           select=c("Sepal.Length", "Sepal.Width", "Species"), 
                           Species %in% c("setosa","virginica")) 

irisSVMplotByCost(iris.sepal.subset, svmcost=1)

irisSVMplotByCost(iris.sepal.subset, svmcost=20)

iris.petal.subset = subset(iris, 
                     select=c("Petal.Length", "Petal.Width",  "Species"),
                     Species %in% c('versicolor', 'virginica')) 

irisSVMplotByCost(iris.petal.subset, xlab='Petal.Length', ylab='Petal.Width', svmcost = 1)

irisSVMplotByCost(iris.petal.subset, xlab='Petal.Length', ylab='Petal.Width', svmcost = 20)

Calling Different Machine-Learning Algorithms through caret

library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
rpartfit <- train(Species ~.  , data = iris, method = "rpart")
#summary(rpartfit)

rffit    <- train(Species ~.  , data = iris, method = "rf")
## Loading required package: randomForest
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
## The following object is masked from 'package:dplyr':
## 
##     combine
#summary(rffit)

svmfit   <- train(Species ~.  , data = iris, method = "svmLinear")
## Loading required package: kernlab
## 
## Attaching package: 'kernlab'
## The following object is masked from 'package:ggplot2':
## 
##     alpha
## The following object is masked from 'package:modeltools':
## 
##     prior
## (Many repeated "Warning in data.row.names(...): some row.names duplicated:
##  ... --> row.names NOT used" warnings from bootstrap resampling omitted.)
#summary(svmfit)
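As a possible follow-up sketch, assuming the three train() fits above completed, caret's resamples() collects their resampling accuracies for a side-by-side comparison:

results <- resamples(list(rpart = rpartfit, rf = rffit, svm = svmfit))
summary(results)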

Churn Problem: Data Description

Customer profile

  • state
  • account length
  • area code
  • phone number

Usage behavior

  • international plan
  • voice mail plan, number vmailmessages
  • total day minutes, total day calls, total day charge
  • total eve minutes, total eve calls, total eve charge
  • total night minutes, total night calls, total night charge
  • total intl minutes, total intl calls, total intl charge
  • number customer service calls

Prediction target

  • Churn (Yes/No)

Churn Analysis

library(C50)
data(churn)
str(churnTrain)
## 'data.frame':    3333 obs. of  20 variables:
##  $ state                        : Factor w/ 51 levels "AK","AL","AR",..: 17 36 32 36 37 2 20 25 19 50 ...
##  $ account_length               : int  128 107 137 84 75 118 121 147 117 141 ...
##  $ area_code                    : Factor w/ 3 levels "area_code_408",..: 2 2 2 1 2 3 3 2 1 2 ...
##  $ international_plan           : Factor w/ 2 levels "no","yes": 1 1 1 2 2 2 1 2 1 2 ...
##  $ voice_mail_plan              : Factor w/ 2 levels "no","yes": 2 2 1 1 1 1 2 1 1 2 ...
##  $ number_vmail_messages        : int  25 26 0 0 0 0 24 0 0 37 ...
##  $ total_day_minutes            : num  265 162 243 299 167 ...
##  $ total_day_calls              : int  110 123 114 71 113 98 88 79 97 84 ...
##  $ total_day_charge             : num  45.1 27.5 41.4 50.9 28.3 ...
##  $ total_eve_minutes            : num  197.4 195.5 121.2 61.9 148.3 ...
##  $ total_eve_calls              : int  99 103 110 88 122 101 108 94 80 111 ...
##  $ total_eve_charge             : num  16.78 16.62 10.3 5.26 12.61 ...
##  $ total_night_minutes          : num  245 254 163 197 187 ...
##  $ total_night_calls            : int  91 103 104 89 121 118 118 96 90 97 ...
##  $ total_night_charge           : num  11.01 11.45 7.32 8.86 8.41 ...
##  $ total_intl_minutes           : num  10 13.7 12.2 6.6 10.1 6.3 7.5 7.1 8.7 11.2 ...
##  $ total_intl_calls             : int  3 3 5 7 3 6 7 6 4 5 ...
##  $ total_intl_charge            : num  2.7 3.7 3.29 1.78 2.73 1.7 2.03 1.92 2.35 3.02 ...
##  $ number_customer_service_calls: int  1 1 0 2 3 0 3 0 1 0 ...
##  $ churn                        : Factor w/ 2 levels "yes","no": 2 2 2 2 2 2 2 2 2 2 ...
churnTrain=churnTrain[,!names(churnTrain)%in%c("state", "area_code", "account_length")]
set.seed(2)
# Split the data into two groups (training, test)
ind<-sample(2, nrow(churnTrain), replace=TRUE, prob=c(0.7, 0.3))
trainset=churnTrain[ind==1,]
testset=churnTrain[ind==2,]

library(rpart)
churn.rp<-rpart(churn ~., data=trainset)
plot(churn.rp, margin=0.1)
text(churn.rp, all=TRUE, use.n=TRUE)

predictions <-predict(churn.rp, testset, type="class")
table(testset$churn, predictions)
##      predictions
##       yes  no
##   yes 100  41
##   no   18 859

Pruning the rpart Tree
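The pruning decision is driven by the complexity (CP) table, where rel error is the resubstitution error and xerror the cross-validated error, both relative to the root node; it can be printed directly:

printcp(churn.rp)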

min(churn.rp$cptable[,"xerror"])
## [1] 0.4707602
which.min(churn.rp$cptable[,"xerror"])
## 7 
## 7
churn.cp=churn.rp$cptable[7,"CP"]
prune.tree=prune(churn.rp, cp=churn.cp)

plot(prune.tree, margin=0.1)
text(prune.tree, all=TRUE, use.n=TRUE)

predictions <-predict(prune.tree, testset, type="class")
table(testset$churn, predictions)
##      predictions
##       yes  no
##   yes  95  46
##   no   14 863
# Look at how the cross-validated error (xerror) changes between
# successive rows of the CP table
aa <- churn.rp$cptable
bb <- aa[-1,]
print(head(aa[, "xerror"], -1) - bb[, "xerror"])
##            1            2            3            4            5 
##  0.002923977  0.236842105  0.233918129  0.005847953  0.011695906 
##            6            7 
##  0.038011696 -0.005847953
churn.cp=churn.rp$cptable[4,"CP"]
prune.tree=prune(churn.rp, cp=churn.cp)

plot(prune.tree, margin=0.1)
text(prune.tree, all=TRUE, use.n=TRUE)

predictions <-predict(prune.tree, testset, type="class")
table(testset$churn, predictions)
##      predictions
##       yes  no
##   yes  88  53
##   no   10 867

Analysis with ctree

library(party)
ctfit <-ctree(churn ~., data=trainset)
# For the available plotting options, see party:::plot.BinaryTree
plot(ctfit, type="simple",           # simple nodes, no terminal plots
     inner_panel=node_inner(ctfit,
                            abbreviate = FALSE,  # full variable names
                            pval = TRUE,         # show p-values
                            id = FALSE),         # hide node ids
     terminal_panel=node_terminal(ctfit, 
                                  abbreviate = TRUE,
                                  digits = 1,          # few digits on numbers
                                  fill = c("white"),   # white boxes, not grey
                                  id = FALSE)          # hide node ids
)

Cross-Validation

Holdout validation

A portion of the original sample is randomly selected to form the validation data, and the remainder is used as training data. As a rule, less than one third of the original sample is selected as validation data.

K-fold cross-validation

In K-fold cross-validation, the original sample is partitioned into K subsamples. One subsample is held out to validate the model while the other K-1 subsamples are used for training, and the procedure is repeated K times so that each subsample serves as validation data exactly once.

Leave-One-Out Cross-Validation (LOOCV)

As the name suggests, leave-one-out cross-validation (LOOCV) uses a single observation from the original sample as the validation data and keeps all remaining observations as training data.
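All three schemes map onto caret's trainControl; a minimal sketch (the stratified holdout split here uses createDataPartition, while these notes elsewhere split with sample()):

library(caret)
ctrl.kfold <- trainControl(method = "cv", number = 10)    # K-fold CV with K = 10
ctrl.loocv <- trainControl(method = "LOOCV")              # leave-one-out CV
# Holdout: a stratified 70/30 split of the churn data
idx <- createDataPartition(churnTrain$churn, p = 0.7, list = FALSE)
holdout.train <- churnTrain[idx, ]
holdout.test  <- churnTrain[-idx, ]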

K-Fold Cross-Validation with caret

library(caret)
control=trainControl(method="repeatedcv", number=10, repeats=5)
model =train(churn~., data=trainset, method="rpart", preProcess="scale", trControl=control)
model
## CART 
## 
## 2315 samples
##   16 predictor
##    2 classes: 'yes', 'no' 
## 
## Pre-processing: scaled (16) 
## Resampling: Cross-Validated (10 fold, repeated 5 times) 
## Summary of sample sizes: 2084, 2083, 2084, 2084, 2084, 2082, ... 
## Resampling results across tuning parameters:
## 
##   cp          Accuracy   Kappa      Accuracy SD  Kappa SD 
##   0.05555556  0.9017788  0.5254466  0.02102836   0.1181990
##   0.07456140  0.8662667  0.2586761  0.01686002   0.1630487
##   0.07602339  0.8571899  0.1694504  0.01029635   0.1383364
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was cp = 0.05555556.
predictions <- predict(model, testset)
confusionMatrix(table(predictions, testset$churn))     
## Confusion Matrix and Statistics
## 
##            
## predictions yes  no
##         yes  64  10
##         no   77 867
##                                          
##                Accuracy : 0.9145         
##                  95% CI : (0.8957, 0.931)
##     No Information Rate : 0.8615         
##     P-Value [Acc > NIR] : 1.266e-07      
##                                          
##                   Kappa : 0.5527         
##  Mcnemar's Test P-Value : 1.484e-12      
##                                          
##             Sensitivity : 0.45390        
##             Specificity : 0.98860        
##          Pos Pred Value : 0.86486        
##          Neg Pred Value : 0.91843        
##              Prevalence : 0.13851        
##          Detection Rate : 0.06287        
##    Detection Prevalence : 0.07269        
##       Balanced Accuracy : 0.72125        
##                                          
##        'Positive' Class : yes            
## 

Finding the Most Important Variables

library(rminer)
## Loading required package: kknn
## 
## Attaching package: 'kknn'
## The following object is masked from 'package:caret':
## 
##     contr.dummy
## 
## Attaching package: 'rminer'
## The following object is masked from 'package:party':
## 
##     fit
## The following object is masked from 'package:modeltools':
## 
##     fit
model=fit(churn~.,trainset,model="rpart")
VariableImportance=Importance(model,trainset,method="sensv")
L=list(runs=1,sen=t(VariableImportance$imp),sresponses=VariableImportance$sresponses)
mgraph(L,graph="IMP",leg=names(trainset),col="gray",Grid=10)

Finding Highly Correlated Variables

new_train=trainset[,!names(churnTrain)%in%c("churn", "international_plan", "voice_mail_plan")]
cor_mat=cor(new_train)
highlyCorrelated=findCorrelation(cor_mat, cutoff=0.75)
names(new_train)[highlyCorrelated]
## [1] "total_intl_minutes"  "total_day_charge"    "total_eve_minutes"  
## [4] "total_night_minutes"

ROC Curve (Receiver Operating Characteristic)

  1. The X axis is the false positive rate (FPR): among all negative samples, the
     probability of being classified as positive (a false positive), also written as 1 - specificity.
  2. The Y axis is the true positive rate (TPR): among all positive samples, the
     probability of being classified as positive (a true positive), also known as sensitivity.
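In confusion-matrix terms, TPR = TP / (TP + FN), i.e. sensitivity, and FPR = FP / (FP + TN) = 1 - specificity. The loop below sweeps the decision threshold from 0 to 1 and computes both rates at each point.
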
data(churn)
churnTrain=churnTrain[,!names(churnTrain)%in%c("state", "area_code", "account_length")]
set.seed(2)
# Split the data into two groups (training, test)
ind<-sample(2, nrow(churnTrain), replace=TRUE, prob=c(0.7, 0.3))
trainset=churnTrain[ind==1,]
testset=churnTrain[ind==2,]

library(rpart)
churn.rp<-rpart(churn ~., data=trainset)
plot(churn.rp, margin=0.1)
text(churn.rp, all=TRUE, use.n=TRUE)

predictions <-predict(churn.rp, testset)

res <- data.frame()
# For a range of thresholds, compute TPR and FPR
for(cost in seq(0,1,0.05)) {
  predict_result <- ifelse(predictions[,1] >= cost, 'yes', 'no' )
  cm <- confusionMatrix(factor(predict_result, levels= c('yes','no')), testset$churn)
  sens <- cm$byClass[1]
  spec <- cm$byClass[2]
  rcp <- data.frame(FPR = 1- spec, TPR = sens)
  res <- rbind(res, rcp)
}
res <- rbind(res, data.frame(FPR= 0, TPR= 0))
plot(TPR~ FPR, data = res, col = 'red', type= 'l', xlim=c(0,1), ylim=c(0,1))
points(TPR~ FPR, data = res, col= 'blue')

Plotting the ROC Curve with ROCR

library(ROCR)
## Loading required package: gplots
## 
## Attaching package: 'gplots'
## The following object is masked from 'package:stats':
## 
##     lowess
predictions <- predict(churn.rp, testset, type="prob")
pred.to.roc <- predictions[, 1]
pred.rocr <- prediction(pred.to.roc, as.factor(testset[,(dim(testset)[[2]])]))
perf.rocr <- performance(pred.rocr, measure ="auc", x.measure="cutoff")
perf.tpr.rocr <- performance(pred.rocr, "tpr","fpr")
plot(perf.tpr.rocr, colorize=T,main=paste("AUC:",(perf.rocr@y.values)))

Agglomerative (Hierarchical) Clustering

hclust is computationally expensive: it needs the full pairwise distance matrix, whose size grows with the square of the number of observations, so it is not suitable for large datasets.

if(!file.exists('D:/CXLWorkspace/RS/RML/customer.csv')) {
   download.file('https://raw.githubusercontent.com/ywchiu/rtibame/master/data/customer.csv', 'D:/CXLWorkspace/RS/RML/customer.csv')
}
customer = read.csv('D:/CXLWorkspace/RS/RML/customer.csv', header=TRUE)
head(customer)
##   ID Visit.Time Average.Expense Sex Age
## 1  1          3             5.7   0  10
## 2  2          5            14.5   0  27
## 3  3         16            33.5   0  32
## 4  4          5            15.9   0  30
## 5  5         16            24.9   0  23
## 6  6          3            12.0   0  15
str(customer)
## 'data.frame':    60 obs. of  5 variables:
##  $ ID             : int  1 2 3 4 5 6 7 8 9 10 ...
##  $ Visit.Time     : int  3 5 16 5 16 3 12 14 6 3 ...
##  $ Average.Expense: num  5.7 14.5 33.5 15.9 24.9 12 28.5 18.8 23.8 5.3 ...
##  $ Sex            : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ Age            : int  10 27 32 30 23 15 33 27 16 11 ...
# Drop the ID column and standardize the features
customer = scale(customer[,-1])
hc = hclust(dist(customer, method="euclidean"), method="ward.D2")
plot(hc, hang = -0.01, cex = 0.7)

Changing the Linkage Method

hc2 =hclust(dist(customer), method="single")
plot(hc2, hang =-0.01, cex=0.7)

Clustering with k-means
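k-means chooses K centroids μ1…μK to minimize the total within-cluster sum of squares, the sum over clusters k of sum over x in cluster k of ||x - μk||^2; this is the tot.withinss value used later to pick K.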

set.seed(22)
fit =kmeans(customer, 4)
barplot(t(fit$centers), beside =TRUE,xlab="cluster", ylab="value")

Plotting the Clustering Result

# customer is a matrix after scale(), so coerce it for the formula interface
plot(Visit.Time ~ Average.Expense, data = as.data.frame(customer), col = fit$cluster)

library(cluster)
clusplot(customer, fit$cluster, color=TRUE, shade=TRUE)

Evaluating the Clustering Quality

set.seed(22)
km =kmeans(customer, 4)
kms=silhouette(km$cluster,dist(customer))
summary(kms)
## Silhouette of 60 units in 4 clusters from silhouette.default(x = km$cluster, dist = dist(customer)) :
##  Cluster sizes and average silhouette widths:
##         8        11        16        25 
## 0.5464597 0.4080823 0.3794910 0.5164434 
## Individual silhouette widths:
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##  0.1931  0.4030  0.4890  0.4641  0.5422  0.6333
plot(kms)

Choosing K: plot the within-cluster sum of squares (WSS) against K and look for the elbow where the curve starts to flatten

nk = 2:10
set.seed(22)
WSS = sapply(nk, function(k){
  kmeans(customer, centers=k)$tot.withinss
})
WSS
## [1] 123.49224  88.07028  61.34890  48.76431  47.20813  45.48114  29.58014
## [8]  28.87519  23.21331
plot(nk, WSS, type="l", xlab="number of k", ylab="within sum of squares")

Comparing Clustering Algorithms

1. [Overview of clustering methods](http://scikit-learn.org/stable/modules/clustering.html)
2. [DBSCAN](https://books.google.com.tw/books?id=iPirBwAAQBAJ&pg=PA318&lpg=PA318&dq=machine+learning+with+R+cookbook+DBSCAN&source=bl&ots=_U-4DDdJn9&sig=lEhbcMew3SHkdFfbb2ZEeG43AEI&hl=zh-TW&sa=X&ved=0ahUKEwjin_zrus3QAhWFFpQKHWd8AJYQ6AEIGTAA#v=onepage&q=machine%20learning%20with%20R%20cookbook%20DBSCAN&f=false)

single_c = hclust(dist(customer), method="single")
hc_single = cutree(single_c, k = 4)

complete_c = hclust(dist(customer), method="complete")
hc_complete = cutree(complete_c, k = 4)

set.seed(22)
km = kmeans(customer, 4)
library(fpc)
cs <- cluster.stats(dist(customer), km$cluster)
cs[c("within.cluster.ss","avg.silwidth")]
## $within.cluster.ss
## [1] 61.3489
## 
## $avg.silwidth
## [1] 0.4640587
sapply(list(kmeans=km$cluster, hc_single=hc_single, hc_complete=hc_complete), 
       function(c) {
         cluster.stats(dist(customer), c)[c("within.cluster.ss","avg.silwidth")] 
       } )    
##                   kmeans    hc_single hc_complete
## within.cluster.ss 61.3489   136.0092  65.94076   
## avg.silwidth      0.4640587 0.2481926 0.4255961

Hands-On: Load the Data and Run k-means

data(iris)
data<-iris[,-5]
class<-iris[,5]
results <-kmeans(data,3)
results$size
## [1] 62 38 50
results$cluster
##   [1] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
##  [36] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
##  [71] 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 2 2 2
## [106] 2 1 2 2 2 2 2 2 1 1 2 2 2 2 1 2 1 2 1 2 2 1 1 2 2 2 2 2 1 2 2 2 2 1 2
## [141] 2 2 1 2 2 2 1 2 2 1

Hands-On: Examine the Clustering Result

table(class,results$cluster)
##             
## class         1  2  3
##   setosa      0  0 50
##   versicolor 48  2  0
##   virginica  14 36  0
par(mfrow=c(2, 2))
plot(data$Petal.Length,data$Petal.Width,col=results$cluster, main='With Petal color by Cluster')
plot(data$Petal.Length,data$Petal.Width,col=class, main='With Petal color by Species')
plot(data$Sepal.Length, data$Sepal.Width,col=results$cluster,  main='With Sepal color by Cluster')
plot(data$Sepal.Length, data$Sepal.Width,col=class,  main='With Sepal color by Species')

par(mfrow=c(1, 1))

Feature Ranking with FSelector

library(FSelector)
weights=random.forest.importance(churn~., trainset, importance.type=1)

library(dplyr)
order.weights <- weights %>%
  arrange(desc(attr_importance) )

row.names(order.weights) <- row.names(weights)[order(weights, decreasing = TRUE)]
print(order.weights)
##                               attr_importance
## number_customer_service_calls     110.4221189
## international_plan                102.8357903
## total_intl_calls                   52.5283445
## total_day_minutes                  51.1937660
## total_day_charge                   50.7530737
## total_intl_charge                  34.2814776
## total_eve_charge                   33.2878651
## total_eve_minutes                  33.1804801
## total_intl_minutes                 32.3869221
## number_vmail_messages              30.1824698
## voice_mail_plan                    26.7166872
## total_night_charge                 23.4311731
## total_night_minutes                22.6066758
## total_night_calls                   2.8223601
## total_day_calls                    -0.7725924
## total_eve_calls                    -1.9159871
subset=cutoff.k(weights, 5)
f =as.simple.formula(subset, "Class")
print(f)
## Class ~ number_customer_service_calls + international_plan + 
##     total_intl_calls + total_day_minutes + total_day_charge
## <environment: 0x0000000043a7d3d8>

Ranking Features with the caret Package

library(caret)
control=trainControl(method="repeatedcv", number=10, repeats=3)
model =train(churn~., data=trainset, method="rpart", preProcess="scale", trControl=control)
importance =varImp(model, scale=FALSE)
importance
## rpart variable importance
## 
##                               Overall
## number_customer_service_calls 116.015
## total_day_minutes             106.988
## total_day_charge              100.648
## international_planyes          86.789
## voice_mail_planyes             25.974
## total_eve_charge               23.097
## total_eve_minutes              23.097
## number_vmail_messages          19.885
## total_intl_minutes              6.347
## total_eve_calls                 0.000
## total_day_calls                 0.000
## total_night_calls               0.000
## total_intl_charge               0.000
## total_night_minutes             0.000
## total_intl_calls                0.000
## total_night_charge              0.000
plot(importance)

Using rminer

library(rminer)
model=fit(churn~.,trainset,model="rpart")
VariableImportance=Importance(model,trainset,method="sensv")
L=list(runs=1,sen=t(VariableImportance$imp),sresponses=VariableImportance$sresponses)
mgraph(L,graph="IMP",leg=names(trainset),col="gray",Grid=10)

Finding the Feature Subset with the Highest Accuracy

evaluator = function(subset){
  k = 5
  set.seed(2)
  ind = sample(5, nrow(trainset), replace=TRUE)
  results = sapply(1:k, function(i){
    train = trainset[ind==i,]
    test  = trainset[ind!=i,]
    # Fit on this fold's training split and score accuracy on the held-out split
    tree = rpart(as.simple.formula(subset, "churn"), train)
    error.rate = sum(test$churn != predict(tree, test, type="class")) / nrow(test)
    return(1 - error.rate)
  })
  return(mean(results))
}

attr.subset <- hill.climbing.search(names(trainset)[!names(trainset)%in%"churn"], evaluator)
f <- as.simple.formula(attr.subset, "churn")
print(f)
## churn ~ international_plan + voice_mail_plan + total_day_calls + 
##     total_day_charge + total_eve_minutes + total_eve_calls + 
##     total_eve_charge + total_intl_minutes + total_intl_calls + 
##     total_intl_charge + number_customer_service_calls
## <environment: 0x000000004090b490>

Principal Component Analysis (PCA)
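Each principal component is a linear combination of the original variables chosen to capture as much of the remaining variance as possible, with successive components orthogonal to one another. With center=TRUE and scale=TRUE, prcomp effectively works on the correlation matrix, so the squared standard deviations sdev^2 are its eigenvalues; that is why the Kaiser criterion below keeps the components with sdev^2 > 1.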

data(swiss)
swiss=swiss[,-1]
swiss.pca=prcomp(swiss,center =TRUE,scale=TRUE)
swiss.pca
## Standard deviations:
## [1] 1.6228065 1.0354873 0.9033447 0.5592765 0.4067472
## 
## Rotation:
##                          PC1         PC2          PC3        PC4
## Agriculture       0.52396452 -0.25834215  0.003003672 -0.8090741
## Examination      -0.57185792 -0.01145981 -0.039840522 -0.4224580
## Education        -0.49150243  0.19028476  0.539337412 -0.3321615
## Catholic          0.38530580  0.36956307  0.725888143  0.1007965
## Infant.Mortality  0.09167606  0.87197641 -0.424976789 -0.2154928
##                          PC5
## Agriculture       0.06411415
## Examination      -0.70198942
## Education         0.56656945
## Catholic         -0.42176895
## Infant.Mortality  0.06488642
summary(swiss.pca)
## Importance of components:
##                           PC1    PC2    PC3     PC4     PC5
## Standard deviation     1.6228 1.0355 0.9033 0.55928 0.40675
## Proportion of Variance 0.5267 0.2145 0.1632 0.06256 0.03309
## Cumulative Proportion  0.5267 0.7411 0.9043 0.96691 1.00000
predict(swiss.pca, newdata=head(swiss, 1))
##                   PC1       PC2        PC3      PC4       PC5
## Courtelary -0.9390479 0.8047122 -0.8118681 1.000307 0.4618643
## Scree plot
screeplot(swiss.pca, type="barplot")

screeplot(swiss.pca, type="line")

## Decide how many components to keep (Kaiser criterion: variance > 1)
swiss.pca$sdev
## [1] 1.6228065 1.0354873 0.9033447 0.5592765 0.4067472
swiss.pca$sdev^2
## [1] 2.6335008 1.0722340 0.8160316 0.3127902 0.1654433
which(swiss.pca$sdev^2>1)
## [1] 1 2
screeplot(swiss.pca, type="line")
abline(h=1, col="red", lty=3)

## PCA biplot
plot(swiss.pca$x[,1], swiss.pca$x[,2], xlim=c(-4,4))
text(swiss.pca$x[,1], swiss.pca$x[,2], rownames(swiss.pca$x), cex=0.7, pos=4, col="red")

biplot(swiss.pca)

Dimensionality Reduction in Practice: Building a Well-Being Index

  • County/city
  • Sales of profit-seeking enterprises
  • Share of government expenditure spent on economic development
  • Average disposable income per income earner

if(!file.exists('D:/CXLWorkspace/RS/RML/eco_index.csv')) {
    download.file('https://raw.githubusercontent.com/ywchiu/riii/master/data/eco_index.csv', 'D:/CXLWorkspace/RS/RML/eco_index.csv')  
}


dataset <-read.csv('D:/CXLWorkspace/RS/RML/eco_index.csv',head=TRUE, sep=',', row.names=1)
pc.cr <-princomp(dataset, cor=TRUE)
plot(pc.cr)

## Scree plot
screeplot(pc.cr, type="lines")
abline(h=1, lty=3)

## PCA biplot
biplot(pc.cr)

## Bar chart of the sorted (sign-flipped) first-component scores
barplot(sort(-pc.cr$scores[,1], TRUE))

Image Compression with PCA

[Lenna](https://zh.wikipedia.org/wiki/%E8%90%8A%E5%A8%9C%E5%9C%96)
[Image Compression with PCA](https://ryancquan.com/blog/2014/10/07/image-compression-pca/)
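A minimal sketch of the idea (assuming a grayscale image already loaded as a numeric matrix img, e.g. the 512x512 Lenna image): run PCA on the pixel matrix, keep the first k components, and reconstruct.

pca <- prcomp(img, center = TRUE)                 # PCA on the pixel rows
k <- 20                                           # number of components kept
recon <- pca$x[, 1:k] %*% t(pca$rotation[, 1:k])  # rank-k reconstruction
recon <- sweep(recon, 2, pca$center, "+")         # add the column means back
image(recon, col = grey(seq(0, 1, length = 256)))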