required_pkgs = c("tidyverse", "mlbench", "rpart", "rpart.plot", "caret", "Metrics")
to_install = required_pkgs[!required_pkgs %in% installed.packages()[, "Package"]]
if (length(to_install) > 0) install.packages(to_install)

library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## āœ” dplyr     1.1.4     āœ” readr     2.1.5
## āœ” forcats   1.0.0     āœ” stringr   1.5.1
## āœ” ggplot2   4.0.0     āœ” tibble    3.3.0
## āœ” lubridate 1.9.4     āœ” tidyr     1.3.1
## āœ” purrr     1.1.0     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## āœ– dplyr::filter() masks stats::filter()
## āœ– dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(mlbench)
library(rpart)
library(rpart.plot)
library(caret)
## Loading required package: lattice
## 
## Attaching package: 'caret'
## The following object is masked from 'package:purrr':
## 
##     lift
library(Metrics)
## 
## Attaching package: 'Metrics'
## The following objects are masked from 'package:caret':
## 
##     precision, recall

CART (Classification and Regression Tree)

CART Regression Tree

Menggunakan package mlbench yang berisi dataset BostonHousing karakteristik lingkungan perumahan di area Boston, AS dengan variabel target berupa nilai numerik kontinu.

Tujuan Pemodelan: memprediksi medv median value rumah berdasarkan fitur sosial-ekonomi dan kondisi lingkungan

Variabel respon/ target

  • medv: nilai tengah (median) harga rumah untuk suatu wilayah/area di Boston.
    Karena targetnya kontinu, maka model yang tepat untuk sesi ini adalah pohon regresi (CART Regression).

Variabel Prediktor/ Penjelas

  • crim: tingkat kriminalitas per kapita

  • zn: proporsi lahan hunian untuk lot besar

  • indus: proporsi luas untuk bisnis non-ritel

  • nox: konsentrasi nitrogen oxides

  • rm: rata-rata jumlah kamar per rumah

  • age: proporsi unit hunian tua

  • dis: jarak ke pusat pekerjaan

  • rad: aksesibilitas ke jalan raya

  • tax: pajak properti

  • ptratio: rasio murid–guru

  • lstat: persentase populasi berstatus sosial lebih rendah

  • b: indeks yang terkait komposisi populasi (variabel historis)

  • chas: indikator apakah lokasi berbatasan dengan Charles River (0/1).

    Variabel seperti ini sering diuji dalam EDA karena bisa memisahkan distribusi target antar-kelompok (misal perbedaan medv untuk chas=0 vs chas=1).

Dalam analisis ini saya akan mengambil peubah bebasnya adalah chas, lstat, crim, indus, rad, dan tax.

Dataset ini cocok untuk latihan CART Regression karena:

  1. medv bersifat kontinu sesuai dengan pohon regresi.
  2. Fitur campuran (numerik+biner), tree dapat menangani keduanya secara natural.
  3. Hubungan fitur-targer sering nonlinear dan melibatkan interaksi, tree dapat mengatasi ini tanpa perlu fungsi eksplisit.
data("BostonHousing", package = "mlbench")
df = BostonHousing
df = df %>% 
  select(medv, chas, lstat, crim, indus, rad, tax)

# Atau menggunakan Base R (tanpa library tambahan)
# df = df[, c("chas", "lstat", "crim", "indus", "rad", "tax")]

glimpse(df)
## Rows: 506
## Columns: 7
## $ medv  <dbl> 24.0, 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15.0…
## $ chas  <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
## $ lstat <dbl> 4.98, 9.14, 4.03, 2.94, 5.33, 5.21, 12.43, 19.15, 29.93, 17.10, …
## $ crim  <dbl> 0.00632, 0.02731, 0.02729, 0.03237, 0.06905, 0.02985, 0.08829, 0…
## $ indus <dbl> 2.31, 7.07, 7.07, 2.18, 2.18, 2.18, 7.87, 7.87, 7.87, 7.87, 7.87…
## $ rad   <dbl> 1, 2, 2, 3, 3, 3, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4…
## $ tax   <dbl> 296, 242, 242, 222, 222, 222, 311, 311, 311, 311, 311, 311, 311,…

EDA

summary(df$medv)   # target regresi (median value)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##    5.00   17.02   21.20   22.53   25.00   50.00
dim(df)
## [1] 506   7

Dataset memiliki 506 baris (observasi) dan 7 kolom (variabel)

summary(df)
##       medv       chas        lstat            crim              indus      
##  Min.   : 5.00   0:471   Min.   : 1.73   Min.   : 0.00632   Min.   : 0.46  
##  1st Qu.:17.02   1: 35   1st Qu.: 6.95   1st Qu.: 0.08205   1st Qu.: 5.19  
##  Median :21.20           Median :11.36   Median : 0.25651   Median : 9.69  
##  Mean   :22.53           Mean   :12.65   Mean   : 3.61352   Mean   :11.14  
##  3rd Qu.:25.00           3rd Qu.:16.95   3rd Qu.: 3.67708   3rd Qu.:18.10  
##  Max.   :50.00           Max.   :37.97   Max.   :88.97620   Max.   :27.74  
##       rad              tax       
##  Min.   : 1.000   Min.   :187.0  
##  1st Qu.: 4.000   1st Qu.:279.0  
##  Median : 5.000   Median :330.0  
##  Mean   : 9.549   Mean   :408.2  
##  3rd Qu.:24.000   3rd Qu.:666.0  
##  Max.   :24.000   Max.   :711.0
  • variabel chas merupakan peubah biner dengan nilai 0 (area tidak berbatasan sungai) dan 1 (area berbatasan sungai)

  • medv (nilai median rumah per area) memiliki nilai min. 5 dan max. 50, artinya nilai termurah adalah sebesar 5 dan termahal adalah sebesar 50 (satuan)

# cek missing value 
colSums(is.na(df))
##  medv  chas lstat  crim indus   rad   tax 
##     0     0     0     0     0     0     0

Semua peubah tidak memiliki missing value, artinnya data sudah bersih.

# Histogram & Boxplot: cek distribusi variabel respon 
hist(df$medv, 
     main = "Distribusi medv", 
     xlab = "medv",
     col = "skyblue")

boxplot(df$medv, 
        horizontal = TRUE, 
        main = 'Boxplot medv',
        col = "salmon")

  • Pada Histogram terlihat disrtibusi dari medv menjulur ke kanan (right-skewes) dengan nilai minimum 5 dan maksimum 50.

  • Pada Boxplot terlihat median medv ada di sekitar 21-22, terlihat banyak outlier yang muncul di sisi kanan, artinya ada sejumlah rumah yang nilainya lebih dari rata-rata.

# cek korelasi 
library (corrplot)
## corrplot 0.95 loaded
num_df = df %>% dplyr::select(where(is.numeric))
corr_mat = cor(num_df)

corrplot(
  corr_mat,
  method = "color",
  type = "upper",
  order = "hclust",
  addCoef.col = "black",
  tl.cex = 0.8,
  number.cex = 0.6
)

Hasil korelasi pearson menggunakan corrplot menunjukkan bahwa

Hubungan Positif: Tidak ada hubungan korelasi positif di antara semua peubah prediktor terhadap peubah respon.

Hubungan Negatif:

  • Sangat Kuat : korelasi negatif sangat kuat terjadi antara lstat dengan medv dengan nilai 0.74.

  • cukup kuat: korelasi negatif cukup kuat terjadi antara indus-medv dan tax-medv dengan nilai -0.48 dan -0.47.

  • lemah: korelasi negatif lemah terjadi antara crim-medv dan rad-medv dengan nilai -0.39 dan -0.38.

# distribusi medv untuk kategori chas 
boxplot(medv ~ chas, 
        data = df, 
        main = "medv per kategori chas", 
        xlab = "chas", 
        yab = "medv",
        col = c("skyblue", "orange"))

Grafik Boxplot di atas menunjukkan distribusi peubah respon medv untuk dua kategori chas. Ketika chas = 1 artinya berbatasan dengan Charles River dan ketika chas = 0 artinya tidak berbatasan dengan Charles River.

  • median pada boxplot dengan chas = 1 lebih tinggi daripada boxplot dengan chas = 0.

  • artinya harga rumah yang dekat dengan sungai Charles cenderung lebih mahal.

  • Pada boxplot chas = 0 memiliki outliers yang cukup banyak, sehingga walaupun rata-rata harga rumah di dekat sungai Charles memang lebih mahal, tetapi masih ada rumah yang juga mahal di area yang jauh dengan sungai Charles.

Splitting Data

set.seed(123)
idx = createDataPartition(df$medv, p = 0.8, list = FALSE)
train = df[idx, ]
test = df[-idx, ]

dim(train); dim(test)
## [1] 407   7
## [1] 99  7
  • sebanyak 407 amatan adalah data latih.

  • sebanyak 99 amatan adalah data uji.

fit = rpart(
  medv ~ .,
  data = train,
  method = "anova",     # kunci untuk regresi
  control = rpart.control(
    minsplit = 20,
    minbucket = 7,
    maxdepth = 30,
    cp = 0              # grow dulu
  )
)

fit
## n= 407 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##   1) root 407 34125.58000 22.510570  
##     2) lstat>=9.63 241  5710.28100 17.429460  
##       4) lstat>=16.085 120  1981.15000 14.300830  
##         8) crim>=5.76921 62   625.79870 11.925810  
##          16) lstat>=19.73 41   388.98880 10.868290  
##            32) lstat>=26.425 15   163.90930  9.506667 *
##            33) lstat< 26.425 26   181.22460 11.653850  
##              66) lstat< 21.15 7    56.84857 10.214290 *
##              67) lstat>=21.15 19   104.52530 12.184210 *
##          17) lstat< 19.73 21   101.43810 13.990480  
##            34) lstat< 18.09 14    56.85714 13.485710 *
##            35) lstat>=18.09 7    33.88000 15.000000 *
##         9) crim< 5.76921 58   631.77880 16.839660  
##          18) crim>=0.65402 29   178.65860 15.493100  
##            36) tax< 551.5 18    33.38278 14.461110 *
##            37) tax>=551.5 11    94.73636 17.181820 *
##          19) crim< 0.65402 29   347.95450 18.186210  
##            38) tax>=434.5 9    77.96222 14.855560 *
##            39) tax< 434.5 20   125.22550 19.685000  
##              78) tax< 276.5 7    35.61714 18.357140 *
##              79) tax>=276.5 13    70.62000 20.400000 *
##       5) lstat< 16.085 121  1389.64400 20.532230  
##        10) tax>=281.5 94   928.59320 19.908510  
##          20) lstat>=11.675 65   551.31780 19.018460  
##            40) lstat>=14.395 25   231.80960 17.696000  
##              80) crim>=2.842095 9   113.07560 15.522220 *
##              81) crim< 2.842095 16    52.28438 18.918750 *
##            41) lstat< 14.395 40   248.45900 19.845000  
##              82) tax< 417.5 23    85.85652 19.043480  
##               164) lstat>=12.465 13    35.32923 18.592310 *
##               165) lstat< 12.465 10    44.44100 19.630000 *
##              83) tax>=417.5 17   127.83530 20.929410 *
##          21) lstat< 11.675 29   210.36970 21.903450  
##            42) crim< 0.279605 14    45.61429 20.657140 *
##            43) crim>=0.279605 15   122.71330 23.066670 *
##        11) tax< 281.5 27   297.16960 22.703700  
##          22) indus>=4.23 20    68.07000 21.550000  
##            44) tax< 240 9    10.47556 20.077780 *
##            45) tax>=240 11    22.12727 22.754550 *
##          23) indus< 4.23 7   126.42000 26.000000 *
##     3) lstat< 9.63 166 13160.04000 29.887350  
##       6) lstat>=5.185 111  4661.56300 26.132430  
##        12) crim< 0.39079 89  1843.28000 24.802250  
##          24) tax>=223.5 81  1254.95400 24.160490  
##            48) lstat>=7.475 37   437.80430 22.264860  
##              96) crim< 0.087825 25   282.16240 21.252000  
##               192) crim>=0.047125 8    94.15500 19.825000 *
##               193) crim< 0.047125 17   164.05060 21.923530 *
##              97) crim>=0.087825 12    76.56250 24.375000 *
##            49) lstat< 7.475 44   572.38910 25.754550  
##              98) indus>=4.895 25   166.53360 24.516000  
##               196) rad>=4.5 13    14.40769 23.230770 *
##               197) rad< 4.5 12   107.38920 25.908330 *
##              99) indus< 4.895 19   317.04530 27.384210 *
##          25) tax< 223.5 8   217.20000 31.300000 *
##        13) crim>=0.39079 22  2023.74600 31.513640  
##          26) indus>=7.17 12  1166.50900 28.391670 *
##          27) indus< 7.17 10   599.92400 35.260000 *
##       7) lstat< 5.185 55  3774.92400 37.465450  
##        14) crim< 0.51938 46  2589.78400 35.523910  
##          28) lstat>=4.475 19   464.08110 31.031580 *
##          29) lstat< 4.475 27  1472.43400 38.685190  
##            58) tax>=248.5 19   841.69680 37.126320 *
##            59) tax< 248.5 8   474.90880 42.387500 *
##        15) crim>=0.51938 9   125.46890 47.388890 *

Output menampilkan:

  • n = jumlah data training yang dipakai,

  • split = aturan pemisah pada node,

  • n (di tiap node) = jumlah data yang masuk node,

  • deviance ā‰ˆ total RSS/SSE dalam node,

  • yval = prediksi node (rata-rata medv),

  • tanda * menunjukkan node terminal (leaf).

Pembuatan Model CART

fit = rpart(
  medv ~ .,
  data = train,
  method = "anova",     # kunci untuk regresi
  control = rpart.control(
    minsplit = 20,
    minbucket = 7,
    maxdepth = 30,
    cp = 0              # grow dulu
  )
)

fit
## n= 407 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##   1) root 407 34125.58000 22.510570  
##     2) lstat>=9.63 241  5710.28100 17.429460  
##       4) lstat>=16.085 120  1981.15000 14.300830  
##         8) crim>=5.76921 62   625.79870 11.925810  
##          16) lstat>=19.73 41   388.98880 10.868290  
##            32) lstat>=26.425 15   163.90930  9.506667 *
##            33) lstat< 26.425 26   181.22460 11.653850  
##              66) lstat< 21.15 7    56.84857 10.214290 *
##              67) lstat>=21.15 19   104.52530 12.184210 *
##          17) lstat< 19.73 21   101.43810 13.990480  
##            34) lstat< 18.09 14    56.85714 13.485710 *
##            35) lstat>=18.09 7    33.88000 15.000000 *
##         9) crim< 5.76921 58   631.77880 16.839660  
##          18) crim>=0.65402 29   178.65860 15.493100  
##            36) tax< 551.5 18    33.38278 14.461110 *
##            37) tax>=551.5 11    94.73636 17.181820 *
##          19) crim< 0.65402 29   347.95450 18.186210  
##            38) tax>=434.5 9    77.96222 14.855560 *
##            39) tax< 434.5 20   125.22550 19.685000  
##              78) tax< 276.5 7    35.61714 18.357140 *
##              79) tax>=276.5 13    70.62000 20.400000 *
##       5) lstat< 16.085 121  1389.64400 20.532230  
##        10) tax>=281.5 94   928.59320 19.908510  
##          20) lstat>=11.675 65   551.31780 19.018460  
##            40) lstat>=14.395 25   231.80960 17.696000  
##              80) crim>=2.842095 9   113.07560 15.522220 *
##              81) crim< 2.842095 16    52.28438 18.918750 *
##            41) lstat< 14.395 40   248.45900 19.845000  
##              82) tax< 417.5 23    85.85652 19.043480  
##               164) lstat>=12.465 13    35.32923 18.592310 *
##               165) lstat< 12.465 10    44.44100 19.630000 *
##              83) tax>=417.5 17   127.83530 20.929410 *
##          21) lstat< 11.675 29   210.36970 21.903450  
##            42) crim< 0.279605 14    45.61429 20.657140 *
##            43) crim>=0.279605 15   122.71330 23.066670 *
##        11) tax< 281.5 27   297.16960 22.703700  
##          22) indus>=4.23 20    68.07000 21.550000  
##            44) tax< 240 9    10.47556 20.077780 *
##            45) tax>=240 11    22.12727 22.754550 *
##          23) indus< 4.23 7   126.42000 26.000000 *
##     3) lstat< 9.63 166 13160.04000 29.887350  
##       6) lstat>=5.185 111  4661.56300 26.132430  
##        12) crim< 0.39079 89  1843.28000 24.802250  
##          24) tax>=223.5 81  1254.95400 24.160490  
##            48) lstat>=7.475 37   437.80430 22.264860  
##              96) crim< 0.087825 25   282.16240 21.252000  
##               192) crim>=0.047125 8    94.15500 19.825000 *
##               193) crim< 0.047125 17   164.05060 21.923530 *
##              97) crim>=0.087825 12    76.56250 24.375000 *
##            49) lstat< 7.475 44   572.38910 25.754550  
##              98) indus>=4.895 25   166.53360 24.516000  
##               196) rad>=4.5 13    14.40769 23.230770 *
##               197) rad< 4.5 12   107.38920 25.908330 *
##              99) indus< 4.895 19   317.04530 27.384210 *
##          25) tax< 223.5 8   217.20000 31.300000 *
##        13) crim>=0.39079 22  2023.74600 31.513640  
##          26) indus>=7.17 12  1166.50900 28.391670 *
##          27) indus< 7.17 10   599.92400 35.260000 *
##       7) lstat< 5.185 55  3774.92400 37.465450  
##        14) crim< 0.51938 46  2589.78400 35.523910  
##          28) lstat>=4.475 19   464.08110 31.031580 *
##          29) lstat< 4.475 27  1472.43400 38.685190  
##            58) tax>=248.5 19   841.69680 37.126320 *
##            59) tax< 248.5 8   474.90880 42.387500 *
##        15) crim>=0.51938 9   125.46890 47.388890 *

Output di atas menunjukkan hasil dari splitting di mana:

node) = nomor identitas node.

split = kondisi pemisahan.

n = jumlah observasi dalam node tersebut

deviance = RSS (Residual Sum of Squares)/JKG (Jumlah Kuadrat Galat) semakin kecil nilainya artinya semakin seragam variance dari node tersebut.

yval = nilai prediksi, ini adalah nilai rata-rata dari medv dari n observasi tersebut

*denotes terminal node = terminal node, artinya tidak ada percabangan lagi di bawahnya dan nilai tersebut merupakan nilai prediksi akhir.

  • Splitting ini dilakukan sebelum pruning karena cp = 0 , sehingga terlihat pohon terlalu ā€œrimbunā€ dan detil.

  • Sehingga dalam node ada yang hanya berisi 7 observasi (terlalu detil).

  • Namun sangat rentan terjadinya overfitting (model terlalu cocok dengan data latih, tetapi saat dilakukan data baru akan gagal).

  • Maka dibutuhkan pruning untuk mendapatkan hasil dan interpretasi yang lebih sederhana dan akurat.

Visualisasi CART

rpart.plot(
  fit,
  type = 2,
  extra = 1,
  fallen.leaves = TRUE,
  tweak = 1.0
)

sesuai dengan interpretasi sebelumnya, visualisasi pohon CART terlalu rimbun dan detil seperti pada grafik di atas. Sehingga dibutuhkan pruning.

Pruning

cp_table = fit$cptable

min_xerr = min(cp_table[, "xerror"])
min_row  = which.min(cp_table[, "xerror"])
xstd_min = cp_table[min_row, "xstd"]

# ambil pohon paling sederhana yang masih dalam 1 std dari minimum
best_row = which(cp_table[, "xerror"] <= (min_xerr + xstd_min))[1]
best_cp  = cp_table[best_row, "CP"]

fit_pruned = prune(fit, cp = best_cp)
best_cp
## [1] 0.007540171
# Visualisasi tree hasill pruning
fit_pruned = prune(fit, cp = best_cp)

rpart.plot(
  fit_pruned,
  type = 2,
  extra = 1,
  fallen.leaves = TRUE,
  tweak = 1.0
)

fit_pruned
## n= 407 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 407 34125.5800 22.51057  
##    2) lstat>=9.63 241  5710.2810 17.42946  
##      4) lstat>=16.085 120  1981.1500 14.30083  
##        8) crim>=5.76921 62   625.7987 11.92581 *
##        9) crim< 5.76921 58   631.7788 16.83966 *
##      5) lstat< 16.085 121  1389.6440 20.53223 *
##    3) lstat< 9.63 166 13160.0400 29.88735  
##      6) lstat>=5.185 111  4661.5630 26.13243  
##       12) crim< 0.39079 89  1843.2800 24.80225  
##         24) tax>=223.5 81  1254.9540 24.16049 *
##         25) tax< 223.5 8   217.2000 31.30000 *
##       13) crim>=0.39079 22  2023.7460 31.51364 *
##      7) lstat< 5.185 55  3774.9240 37.46545  
##       14) crim< 0.51938 46  2589.7840 35.52391  
##         28) lstat>=4.475 19   464.0811 31.03158 *
##         29) lstat< 4.475 27  1472.4340 38.68519 *
##       15) crim>=0.51938 9   125.4689 47.38889 *

Hasil pruning menunjukkan setelah pemangkasan terlihat hanya menghasilkan cabang/pohon yang jauh lebih sedikit.

  • lstat (ekonomi lingkungan) memberikan faktor yang cukup dominan dalam penentuan harga

  • crim (kriminalitas) juga menjadi salah satu faktor penentu harga baik dalam menurunkan harga yang sangat ektrem dan menaikan harga sehingga sangat mahal

    8) crim>=5.76921 62   625.7987 11.92581 *        
    9) crim< 5.76921 58   631.7788 16.83966 *
    ketika tingkat kriminalitas tinggi, maka harganya pun lebih rendah
     15) crim>=0.51938 9   125.4689 47.38889 * 
     12) crim< 0.39079 89  1843.2800 24.80225  
     menunjukkan kriminalitas yang sangat kecil, harga sangat mahal
  • dibandingkan CART sebelumnya pohon ini hanya memiliki 9 node dengan *, artinya pohon jauh lebih sederhana dan lebih stabil untuk memprediksi data.

Evaluasi

Sebelum Pruning

## Evaluasi Model Sebelum Pruning (Tree besar: cp=0)

### 1) Prediksi pada data train & test
pred_train = predict(fit, newdata = train)
pred_test  = predict(fit, newdata = test)

### 2) Fungsi metrik (tanpa paket tambahan)
rmse = function(y, yhat) sqrt(mean((y - yhat)^2))
mae  = function(y, yhat) mean(abs(y - yhat))
r2   = function(y, yhat) 1 - sum((y - yhat)^2) / sum((y - mean(y))^2)

### 3) Hitung metrik
train_rmse = rmse(train$medv, pred_train)
train_mae  = mae(train$medv, pred_train)
train_r2   = r2(train$medv, pred_train)

test_rmse  = rmse(test$medv, pred_test)
test_mae   = mae(test$medv, pred_test)
test_r2    = r2(test$medv, pred_test)

data.frame(
  dataset = c("train", "test"),
  RMSE = c(train_rmse, test_rmse),
  MAE  = c(train_mae,  test_mae),
  R2   = c(train_r2,   test_r2)
)
##   dataset     RMSE      MAE        R2
## 1   train 3.868875 2.682173 0.8214813
## 2    test 4.256762 3.268909 0.7911585

Plot Diagnostik Sebelum Pruning

# Prediksi vs Aktual (Test)
plot(test$medv, pred_test,
     xlab = "Aktual (medv)", ylab = "Prediksi",
     main = "Prediksi vs Aktual (Test) - Sebelum Pruning")
abline(0, 1)

# Residual (Test)
res_test = test$medv - pred_test

hist(res_test, breaks = 30,
     main = "Distribusi Residual (Test) - Sebelum Pruning",
     xlab = "Residual (Aktual - Prediksi)")

plot(pred_test, res_test,
     xlab = "Prediksi", ylab = "Residual",
     main = "Residual vs Prediksi (Test) - Sebelum Pruning")
abline(h = 0, lty = 2)

Sesudah Pruning

## Prediksi
pred_train_pr = predict(fit_pruned, newdata = train)
pred_test_pr  = predict(fit_pruned, newdata = test)

## Metrik
rmse = function(y, yhat) sqrt(mean((y - yhat)^2))
mae  = function(y, yhat) mean(abs(y - yhat))
r2   = function(y, yhat) 1 - sum((y - yhat)^2) / sum((y - mean(y))^2)

# Hasil (pruned)
res_pruned = data.frame(
  dataset = c("train", "test"),
  RMSE = c(rmse(train$medv, pred_train_pr), rmse(test$medv, pred_test_pr)),
  MAE  = c(mae(train$medv,  pred_train_pr), mae(test$medv,  pred_test_pr)),
  R2   = c(r2(train$medv,   pred_train_pr), r2(test$medv,   pred_test_pr))
)

res_pruned
##   dataset     RMSE      MAE        R2
## 1   train 4.489985 3.308679 0.7595615
## 2    test 4.577769 3.328996 0.7584729

Plot Diagnostik Sesudah Pruning

# Prediksi vs Aktual (Test) - Pruned
plot(test$medv, pred_test_pr,
     xlab = "Aktual (medv)", ylab = "Prediksi",
     main = "Prediksi vs Aktual (Test) - Setelah Pruning")
abline(0, 1)

# Residual - Pruned (Test)
res_test_pr = test$medv - pred_test_pr

hist(res_test_pr, breaks = 30,
     main = "Distribusi Residual (Test) - Setelah Pruning",
     xlab = "Residual (Aktual - Prediksi)")

plot(pred_test_pr, res_test_pr,
     xlab = "Prediksi", ylab = "Residual",
     main = "Residual vs Prediksi (Test) - Setelah Pruning")
abline(h = 0, lty = 2)

Perbandingan

## 0) Pastikan fungsi metrik ada
rmse =  function(y, yhat) sqrt(mean((y - yhat)^2))
mae  = function(y, yhat) mean(abs(y - yhat))
r2   = function(y, yhat) 1 - sum((y - yhat)^2) / sum((y - mean(y))^2)

## 1) Prediksi BEFORE pruning (fit)
pred_train = predict(fit,  newdata = train)
pred_test  = predict(fit,  newdata = test)

res_before = data.frame(
  dataset = c("train", "test"),
  RMSE = c(rmse(train$medv, pred_train), rmse(test$medv, pred_test)),
  MAE  = c(mae(train$medv,  pred_train), mae(test$medv,  pred_test)),
  R2   = c(r2(train$medv,   pred_train), r2(test$medv,   pred_test)),
  Tahap = "before_pruning"
)

## 2) Prediksi AFTER pruning (fit_pruned)  -> pastikan fit_pruned sudah ada!
pred_train_pr = predict(fit_pruned, newdata = train)
pred_test_pr  = predict(fit_pruned, newdata = test)

res_pruned = data.frame(
  dataset = c("train", "test"),
  RMSE = c(rmse(train$medv, pred_train_pr), rmse(test$medv, pred_test_pr)),
  MAE  = c(mae(train$medv,  pred_train_pr), mae(test$medv,  pred_test_pr)),
  R2   = c(r2(train$medv,   pred_train_pr), r2(test$medv,   pred_test_pr)),
  Tahap = "after_pruning"
)

## 3) Gabungkan hasil
res_compare = dplyr::bind_rows(res_before, res_pruned) %>%
  dplyr::select(Tahap, dataset, RMSE, MAE, R2)

res_compare
##            Tahap dataset     RMSE      MAE        R2
## 1 before_pruning   train 3.868875 2.682173 0.8214813
## 2 before_pruning    test 4.256762 3.268909 0.7911585
## 3  after_pruning   train 4.489985 3.308679 0.7595615
## 4  after_pruning    test 4.577769 3.328996 0.7584729