required_pkgs = c("tidyverse", "mlbench", "rpart", "rpart.plot", "caret", "Metrics")
to_install = required_pkgs[!required_pkgs %in% installed.packages()[, "Package"]]
if (length(to_install) > 0) install.packages(to_install)
library(tidyverse)
## āā Attaching core tidyverse packages āāāāāāāāāāāāāāāāāāāāāāāā tidyverse 2.0.0 āā
## ā dplyr 1.1.4 ā readr 2.1.5
## ā forcats 1.0.0 ā stringr 1.5.1
## ā ggplot2 4.0.0 ā tibble 3.3.0
## ā lubridate 1.9.4 ā tidyr 1.3.1
## ā purrr 1.1.0
## āā Conflicts āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā tidyverse_conflicts() āā
## ā dplyr::filter() masks stats::filter()
## ā dplyr::lag() masks stats::lag()
## ā¹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(mlbench)
library(rpart)
library(rpart.plot)
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
## The following object is masked from 'package:purrr':
##
## lift
library(Metrics)
##
## Attaching package: 'Metrics'
## The following objects are masked from 'package:caret':
##
## precision, recall
Menggunakan package mlbench yang berisi dataset
BostonHousing karakteristik lingkungan perumahan di area
Boston, AS dengan variabel target berupa nilai numerik kontinu.
Tujuan Pemodelan: memprediksi medv
median value rumah berdasarkan fitur sosial-ekonomi dan kondisi
lingkungan
Variabel respon/ target
medv: nilai tengah (median) harga
rumah untuk suatu wilayah/area di Boston.Variabel Prediktor/ Penjelas
crim: tingkat kriminalitas per
kapita
zn: proporsi lahan hunian untuk lot besar
indus: proporsi luas untuk bisnis
non-ritel
nox: konsentrasi nitrogen oxides
rm: rata-rata jumlah kamar per rumah
age: proporsi unit hunian tua
dis: jarak ke pusat pekerjaan
rad: aksesibilitas ke jalan
raya
tax: pajak properti
ptratio: rasio muridāguru
lstat: persentase populasi
berstatus sosial lebih rendah
b: indeks yang terkait komposisi populasi (variabel
historis)
chas: indikator apakah lokasi berbatasan
dengan Charles River (0/1).
Variabel seperti ini sering diuji dalam EDA karena bisa memisahkan
distribusi target antar-kelompok (misal perbedaan medv
untuk chas=0 vs chas=1).
Dalam analisis ini saya akan mengambil peubah bebasnya adalah
chas, lstat, crim,
indus, rad, dan tax.
Dataset ini cocok untuk latihan CART Regression karena:
medv bersifat kontinu sesuai dengan pohon regresi.data("BostonHousing", package = "mlbench")
df = BostonHousing
df = df %>%
select(medv, chas, lstat, crim, indus, rad, tax)
# Atau menggunakan Base R (tanpa library tambahan)
# df = df[, c("chas", "lstat", "crim", "indus", "rad", "tax")]
glimpse(df)
## Rows: 506
## Columns: 7
## $ medv <dbl> 24.0, 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15.0ā¦
## $ chas <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0ā¦
## $ lstat <dbl> 4.98, 9.14, 4.03, 2.94, 5.33, 5.21, 12.43, 19.15, 29.93, 17.10, ā¦
## $ crim <dbl> 0.00632, 0.02731, 0.02729, 0.03237, 0.06905, 0.02985, 0.08829, 0ā¦
## $ indus <dbl> 2.31, 7.07, 7.07, 2.18, 2.18, 2.18, 7.87, 7.87, 7.87, 7.87, 7.87ā¦
## $ rad <dbl> 1, 2, 2, 3, 3, 3, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4ā¦
## $ tax <dbl> 296, 242, 242, 222, 222, 222, 311, 311, 311, 311, 311, 311, 311,ā¦
summary(df$medv) # target regresi (median value)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 5.00 17.02 21.20 22.53 25.00 50.00
dim(df)
## [1] 506 7
Dataset memiliki 506 baris (observasi) dan 7 kolom (variabel)
summary(df)
## medv chas lstat crim indus
## Min. : 5.00 0:471 Min. : 1.73 Min. : 0.00632 Min. : 0.46
## 1st Qu.:17.02 1: 35 1st Qu.: 6.95 1st Qu.: 0.08205 1st Qu.: 5.19
## Median :21.20 Median :11.36 Median : 0.25651 Median : 9.69
## Mean :22.53 Mean :12.65 Mean : 3.61352 Mean :11.14
## 3rd Qu.:25.00 3rd Qu.:16.95 3rd Qu.: 3.67708 3rd Qu.:18.10
## Max. :50.00 Max. :37.97 Max. :88.97620 Max. :27.74
## rad tax
## Min. : 1.000 Min. :187.0
## 1st Qu.: 4.000 1st Qu.:279.0
## Median : 5.000 Median :330.0
## Mean : 9.549 Mean :408.2
## 3rd Qu.:24.000 3rd Qu.:666.0
## Max. :24.000 Max. :711.0
variabel chas merupakan peubah biner dengan nilai 0
(area tidak berbatasan sungai) dan 1 (area berbatasan sungai)
medv (nilai median rumah per area) memiliki nilai
min. 5 dan max. 50, artinya nilai termurah adalah sebesar 5 dan termahal
adalah sebesar 50 (satuan)
# cek missing value
colSums(is.na(df))
## medv chas lstat crim indus rad tax
## 0 0 0 0 0 0 0
Semua peubah tidak memiliki missing value, artinnya data sudah bersih.
# Histogram & Boxplot: cek distribusi variabel respon
hist(df$medv,
main = "Distribusi medv",
xlab = "medv",
col = "skyblue")
boxplot(df$medv,
horizontal = TRUE,
main = 'Boxplot medv',
col = "salmon")
Pada Histogram terlihat disrtibusi dari medv
menjulur ke kanan (right-skewes) dengan nilai minimum 5 dan maksimum
50.
Pada Boxplot terlihat median medv ada di sekitar
21-22, terlihat banyak outlier yang muncul di sisi kanan, artinya ada
sejumlah rumah yang nilainya lebih dari rata-rata.
# cek korelasi
library (corrplot)
## corrplot 0.95 loaded
num_df = df %>% dplyr::select(where(is.numeric))
corr_mat = cor(num_df)
corrplot(
corr_mat,
method = "color",
type = "upper",
order = "hclust",
addCoef.col = "black",
tl.cex = 0.8,
number.cex = 0.6
)
Hasil korelasi pearson menggunakan corrplot menunjukkan bahwa
Hubungan Positif: Tidak ada hubungan korelasi positif di antara semua peubah prediktor terhadap peubah respon.
Hubungan Negatif:
Sangat Kuat : korelasi negatif sangat kuat terjadi antara
lstat dengan medv dengan nilai 0.74.
cukup kuat: korelasi negatif cukup kuat terjadi antara
indus-medv dan
tax-medv dengan nilai -0.48 dan
-0.47.
lemah: korelasi negatif lemah terjadi antara
crim-medv dan
rad-medv dengan nilai -0.39 dan
-0.38.
# distribusi medv untuk kategori chas
boxplot(medv ~ chas,
data = df,
main = "medv per kategori chas",
xlab = "chas",
yab = "medv",
col = c("skyblue", "orange"))
Grafik Boxplot di atas menunjukkan distribusi peubah respon
medv untuk dua kategori chas. Ketika
chas = 1 artinya berbatasan dengan Charles River dan ketika
chas = 0 artinya tidak berbatasan dengan Charles River.
median pada boxplot dengan chas = 1 lebih tinggi
daripada boxplot dengan chas = 0.
artinya harga rumah yang dekat dengan sungai Charles cenderung lebih mahal.
Pada boxplot chas = 0 memiliki outliers yang cukup
banyak, sehingga walaupun rata-rata harga rumah di dekat sungai Charles
memang lebih mahal, tetapi masih ada rumah yang juga mahal di area yang
jauh dengan sungai Charles.
set.seed(123)
idx = createDataPartition(df$medv, p = 0.8, list = FALSE)
train = df[idx, ]
test = df[-idx, ]
dim(train); dim(test)
## [1] 407 7
## [1] 99 7
sebanyak 407 amatan adalah data latih.
sebanyak 99 amatan adalah data uji.
fit = rpart(
medv ~ .,
data = train,
method = "anova", # kunci untuk regresi
control = rpart.control(
minsplit = 20,
minbucket = 7,
maxdepth = 30,
cp = 0 # grow dulu
)
)
fit
## n= 407
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 407 34125.58000 22.510570
## 2) lstat>=9.63 241 5710.28100 17.429460
## 4) lstat>=16.085 120 1981.15000 14.300830
## 8) crim>=5.76921 62 625.79870 11.925810
## 16) lstat>=19.73 41 388.98880 10.868290
## 32) lstat>=26.425 15 163.90930 9.506667 *
## 33) lstat< 26.425 26 181.22460 11.653850
## 66) lstat< 21.15 7 56.84857 10.214290 *
## 67) lstat>=21.15 19 104.52530 12.184210 *
## 17) lstat< 19.73 21 101.43810 13.990480
## 34) lstat< 18.09 14 56.85714 13.485710 *
## 35) lstat>=18.09 7 33.88000 15.000000 *
## 9) crim< 5.76921 58 631.77880 16.839660
## 18) crim>=0.65402 29 178.65860 15.493100
## 36) tax< 551.5 18 33.38278 14.461110 *
## 37) tax>=551.5 11 94.73636 17.181820 *
## 19) crim< 0.65402 29 347.95450 18.186210
## 38) tax>=434.5 9 77.96222 14.855560 *
## 39) tax< 434.5 20 125.22550 19.685000
## 78) tax< 276.5 7 35.61714 18.357140 *
## 79) tax>=276.5 13 70.62000 20.400000 *
## 5) lstat< 16.085 121 1389.64400 20.532230
## 10) tax>=281.5 94 928.59320 19.908510
## 20) lstat>=11.675 65 551.31780 19.018460
## 40) lstat>=14.395 25 231.80960 17.696000
## 80) crim>=2.842095 9 113.07560 15.522220 *
## 81) crim< 2.842095 16 52.28438 18.918750 *
## 41) lstat< 14.395 40 248.45900 19.845000
## 82) tax< 417.5 23 85.85652 19.043480
## 164) lstat>=12.465 13 35.32923 18.592310 *
## 165) lstat< 12.465 10 44.44100 19.630000 *
## 83) tax>=417.5 17 127.83530 20.929410 *
## 21) lstat< 11.675 29 210.36970 21.903450
## 42) crim< 0.279605 14 45.61429 20.657140 *
## 43) crim>=0.279605 15 122.71330 23.066670 *
## 11) tax< 281.5 27 297.16960 22.703700
## 22) indus>=4.23 20 68.07000 21.550000
## 44) tax< 240 9 10.47556 20.077780 *
## 45) tax>=240 11 22.12727 22.754550 *
## 23) indus< 4.23 7 126.42000 26.000000 *
## 3) lstat< 9.63 166 13160.04000 29.887350
## 6) lstat>=5.185 111 4661.56300 26.132430
## 12) crim< 0.39079 89 1843.28000 24.802250
## 24) tax>=223.5 81 1254.95400 24.160490
## 48) lstat>=7.475 37 437.80430 22.264860
## 96) crim< 0.087825 25 282.16240 21.252000
## 192) crim>=0.047125 8 94.15500 19.825000 *
## 193) crim< 0.047125 17 164.05060 21.923530 *
## 97) crim>=0.087825 12 76.56250 24.375000 *
## 49) lstat< 7.475 44 572.38910 25.754550
## 98) indus>=4.895 25 166.53360 24.516000
## 196) rad>=4.5 13 14.40769 23.230770 *
## 197) rad< 4.5 12 107.38920 25.908330 *
## 99) indus< 4.895 19 317.04530 27.384210 *
## 25) tax< 223.5 8 217.20000 31.300000 *
## 13) crim>=0.39079 22 2023.74600 31.513640
## 26) indus>=7.17 12 1166.50900 28.391670 *
## 27) indus< 7.17 10 599.92400 35.260000 *
## 7) lstat< 5.185 55 3774.92400 37.465450
## 14) crim< 0.51938 46 2589.78400 35.523910
## 28) lstat>=4.475 19 464.08110 31.031580 *
## 29) lstat< 4.475 27 1472.43400 38.685190
## 58) tax>=248.5 19 841.69680 37.126320 *
## 59) tax< 248.5 8 474.90880 42.387500 *
## 15) crim>=0.51938 9 125.46890 47.388890 *
Output menampilkan:
n = jumlah data training yang dipakai,
split = aturan pemisah pada node,
n (di tiap node) = jumlah data yang masuk
node,
deviance ā total RSS/SSE dalam node,
yval = prediksi node (rata-rata
medv),
tanda * menunjukkan node terminal (leaf).
fit = rpart(
medv ~ .,
data = train,
method = "anova", # kunci untuk regresi
control = rpart.control(
minsplit = 20,
minbucket = 7,
maxdepth = 30,
cp = 0 # grow dulu
)
)
fit
## n= 407
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 407 34125.58000 22.510570
## 2) lstat>=9.63 241 5710.28100 17.429460
## 4) lstat>=16.085 120 1981.15000 14.300830
## 8) crim>=5.76921 62 625.79870 11.925810
## 16) lstat>=19.73 41 388.98880 10.868290
## 32) lstat>=26.425 15 163.90930 9.506667 *
## 33) lstat< 26.425 26 181.22460 11.653850
## 66) lstat< 21.15 7 56.84857 10.214290 *
## 67) lstat>=21.15 19 104.52530 12.184210 *
## 17) lstat< 19.73 21 101.43810 13.990480
## 34) lstat< 18.09 14 56.85714 13.485710 *
## 35) lstat>=18.09 7 33.88000 15.000000 *
## 9) crim< 5.76921 58 631.77880 16.839660
## 18) crim>=0.65402 29 178.65860 15.493100
## 36) tax< 551.5 18 33.38278 14.461110 *
## 37) tax>=551.5 11 94.73636 17.181820 *
## 19) crim< 0.65402 29 347.95450 18.186210
## 38) tax>=434.5 9 77.96222 14.855560 *
## 39) tax< 434.5 20 125.22550 19.685000
## 78) tax< 276.5 7 35.61714 18.357140 *
## 79) tax>=276.5 13 70.62000 20.400000 *
## 5) lstat< 16.085 121 1389.64400 20.532230
## 10) tax>=281.5 94 928.59320 19.908510
## 20) lstat>=11.675 65 551.31780 19.018460
## 40) lstat>=14.395 25 231.80960 17.696000
## 80) crim>=2.842095 9 113.07560 15.522220 *
## 81) crim< 2.842095 16 52.28438 18.918750 *
## 41) lstat< 14.395 40 248.45900 19.845000
## 82) tax< 417.5 23 85.85652 19.043480
## 164) lstat>=12.465 13 35.32923 18.592310 *
## 165) lstat< 12.465 10 44.44100 19.630000 *
## 83) tax>=417.5 17 127.83530 20.929410 *
## 21) lstat< 11.675 29 210.36970 21.903450
## 42) crim< 0.279605 14 45.61429 20.657140 *
## 43) crim>=0.279605 15 122.71330 23.066670 *
## 11) tax< 281.5 27 297.16960 22.703700
## 22) indus>=4.23 20 68.07000 21.550000
## 44) tax< 240 9 10.47556 20.077780 *
## 45) tax>=240 11 22.12727 22.754550 *
## 23) indus< 4.23 7 126.42000 26.000000 *
## 3) lstat< 9.63 166 13160.04000 29.887350
## 6) lstat>=5.185 111 4661.56300 26.132430
## 12) crim< 0.39079 89 1843.28000 24.802250
## 24) tax>=223.5 81 1254.95400 24.160490
## 48) lstat>=7.475 37 437.80430 22.264860
## 96) crim< 0.087825 25 282.16240 21.252000
## 192) crim>=0.047125 8 94.15500 19.825000 *
## 193) crim< 0.047125 17 164.05060 21.923530 *
## 97) crim>=0.087825 12 76.56250 24.375000 *
## 49) lstat< 7.475 44 572.38910 25.754550
## 98) indus>=4.895 25 166.53360 24.516000
## 196) rad>=4.5 13 14.40769 23.230770 *
## 197) rad< 4.5 12 107.38920 25.908330 *
## 99) indus< 4.895 19 317.04530 27.384210 *
## 25) tax< 223.5 8 217.20000 31.300000 *
## 13) crim>=0.39079 22 2023.74600 31.513640
## 26) indus>=7.17 12 1166.50900 28.391670 *
## 27) indus< 7.17 10 599.92400 35.260000 *
## 7) lstat< 5.185 55 3774.92400 37.465450
## 14) crim< 0.51938 46 2589.78400 35.523910
## 28) lstat>=4.475 19 464.08110 31.031580 *
## 29) lstat< 4.475 27 1472.43400 38.685190
## 58) tax>=248.5 19 841.69680 37.126320 *
## 59) tax< 248.5 8 474.90880 42.387500 *
## 15) crim>=0.51938 9 125.46890 47.388890 *
Output di atas menunjukkan hasil dari splitting di mana:
node) = nomor identitas node.
split = kondisi pemisahan.
n = jumlah observasi dalam node tersebut
deviance = RSS (Residual Sum of Squares)/JKG (Jumlah
Kuadrat Galat) semakin kecil nilainya artinya semakin seragam variance
dari node tersebut.
yval = nilai prediksi, ini adalah nilai rata-rata dari
medv dari n observasi tersebut
*denotes terminal node = terminal node, artinya tidak
ada percabangan lagi di bawahnya dan nilai tersebut merupakan nilai
prediksi akhir.
Splitting ini dilakukan sebelum pruning karena
cp = 0 , sehingga terlihat pohon terlalu ārimbunā dan
detil.
Sehingga dalam node ada yang hanya berisi 7 observasi (terlalu detil).
Namun sangat rentan terjadinya overfitting (model terlalu cocok dengan data latih, tetapi saat dilakukan data baru akan gagal).
Maka dibutuhkan pruning untuk mendapatkan hasil dan interpretasi yang lebih sederhana dan akurat.
rpart.plot(
fit,
type = 2,
extra = 1,
fallen.leaves = TRUE,
tweak = 1.0
)
sesuai dengan interpretasi sebelumnya, visualisasi pohon CART terlalu rimbun dan detil seperti pada grafik di atas. Sehingga dibutuhkan pruning.
cp_table = fit$cptable
min_xerr = min(cp_table[, "xerror"])
min_row = which.min(cp_table[, "xerror"])
xstd_min = cp_table[min_row, "xstd"]
# ambil pohon paling sederhana yang masih dalam 1 std dari minimum
best_row = which(cp_table[, "xerror"] <= (min_xerr + xstd_min))[1]
best_cp = cp_table[best_row, "CP"]
fit_pruned = prune(fit, cp = best_cp)
best_cp
## [1] 0.007540171
# Visualisasi tree hasill pruning
fit_pruned = prune(fit, cp = best_cp)
rpart.plot(
fit_pruned,
type = 2,
extra = 1,
fallen.leaves = TRUE,
tweak = 1.0
)
fit_pruned
## n= 407
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 407 34125.5800 22.51057
## 2) lstat>=9.63 241 5710.2810 17.42946
## 4) lstat>=16.085 120 1981.1500 14.30083
## 8) crim>=5.76921 62 625.7987 11.92581 *
## 9) crim< 5.76921 58 631.7788 16.83966 *
## 5) lstat< 16.085 121 1389.6440 20.53223 *
## 3) lstat< 9.63 166 13160.0400 29.88735
## 6) lstat>=5.185 111 4661.5630 26.13243
## 12) crim< 0.39079 89 1843.2800 24.80225
## 24) tax>=223.5 81 1254.9540 24.16049 *
## 25) tax< 223.5 8 217.2000 31.30000 *
## 13) crim>=0.39079 22 2023.7460 31.51364 *
## 7) lstat< 5.185 55 3774.9240 37.46545
## 14) crim< 0.51938 46 2589.7840 35.52391
## 28) lstat>=4.475 19 464.0811 31.03158 *
## 29) lstat< 4.475 27 1472.4340 38.68519 *
## 15) crim>=0.51938 9 125.4689 47.38889 *
Hasil pruning menunjukkan setelah pemangkasan terlihat hanya menghasilkan cabang/pohon yang jauh lebih sedikit.
lstat (ekonomi lingkungan) memberikan faktor yang
cukup dominan dalam penentuan harga
crim (kriminalitas) juga menjadi salah satu faktor
penentu harga baik dalam menurunkan harga yang sangat ektrem dan
menaikan harga sehingga sangat mahal
8) crim>=5.76921 62 625.7987 11.92581 *
9) crim< 5.76921 58 631.7788 16.83966 *
ketika tingkat kriminalitas tinggi, maka harganya pun lebih rendah
15) crim>=0.51938 9 125.4689 47.38889 *
12) crim< 0.39079 89 1843.2800 24.80225
menunjukkan kriminalitas yang sangat kecil, harga sangat mahaldibandingkan CART sebelumnya pohon ini hanya memiliki 9 node dengan *, artinya pohon jauh lebih sederhana dan lebih stabil untuk memprediksi data.
## Evaluasi Model Sebelum Pruning (Tree besar: cp=0)
### 1) Prediksi pada data train & test
pred_train = predict(fit, newdata = train)
pred_test = predict(fit, newdata = test)
### 2) Fungsi metrik (tanpa paket tambahan)
rmse = function(y, yhat) sqrt(mean((y - yhat)^2))
mae = function(y, yhat) mean(abs(y - yhat))
r2 = function(y, yhat) 1 - sum((y - yhat)^2) / sum((y - mean(y))^2)
### 3) Hitung metrik
train_rmse = rmse(train$medv, pred_train)
train_mae = mae(train$medv, pred_train)
train_r2 = r2(train$medv, pred_train)
test_rmse = rmse(test$medv, pred_test)
test_mae = mae(test$medv, pred_test)
test_r2 = r2(test$medv, pred_test)
data.frame(
dataset = c("train", "test"),
RMSE = c(train_rmse, test_rmse),
MAE = c(train_mae, test_mae),
R2 = c(train_r2, test_r2)
)
## dataset RMSE MAE R2
## 1 train 3.868875 2.682173 0.8214813
## 2 test 4.256762 3.268909 0.7911585
# Prediksi vs Aktual (Test)
plot(test$medv, pred_test,
xlab = "Aktual (medv)", ylab = "Prediksi",
main = "Prediksi vs Aktual (Test) - Sebelum Pruning")
abline(0, 1)
# Residual (Test)
res_test = test$medv - pred_test
hist(res_test, breaks = 30,
main = "Distribusi Residual (Test) - Sebelum Pruning",
xlab = "Residual (Aktual - Prediksi)")
plot(pred_test, res_test,
xlab = "Prediksi", ylab = "Residual",
main = "Residual vs Prediksi (Test) - Sebelum Pruning")
abline(h = 0, lty = 2)
## Prediksi
pred_train_pr = predict(fit_pruned, newdata = train)
pred_test_pr = predict(fit_pruned, newdata = test)
## Metrik
rmse = function(y, yhat) sqrt(mean((y - yhat)^2))
mae = function(y, yhat) mean(abs(y - yhat))
r2 = function(y, yhat) 1 - sum((y - yhat)^2) / sum((y - mean(y))^2)
# Hasil (pruned)
res_pruned = data.frame(
dataset = c("train", "test"),
RMSE = c(rmse(train$medv, pred_train_pr), rmse(test$medv, pred_test_pr)),
MAE = c(mae(train$medv, pred_train_pr), mae(test$medv, pred_test_pr)),
R2 = c(r2(train$medv, pred_train_pr), r2(test$medv, pred_test_pr))
)
res_pruned
## dataset RMSE MAE R2
## 1 train 4.489985 3.308679 0.7595615
## 2 test 4.577769 3.328996 0.7584729
# Prediksi vs Aktual (Test) - Pruned
plot(test$medv, pred_test_pr,
xlab = "Aktual (medv)", ylab = "Prediksi",
main = "Prediksi vs Aktual (Test) - Setelah Pruning")
abline(0, 1)
# Residual - Pruned (Test)
res_test_pr = test$medv - pred_test_pr
hist(res_test_pr, breaks = 30,
main = "Distribusi Residual (Test) - Setelah Pruning",
xlab = "Residual (Aktual - Prediksi)")
plot(pred_test_pr, res_test_pr,
xlab = "Prediksi", ylab = "Residual",
main = "Residual vs Prediksi (Test) - Setelah Pruning")
abline(h = 0, lty = 2)
## 0) Pastikan fungsi metrik ada
rmse = function(y, yhat) sqrt(mean((y - yhat)^2))
mae = function(y, yhat) mean(abs(y - yhat))
r2 = function(y, yhat) 1 - sum((y - yhat)^2) / sum((y - mean(y))^2)
## 1) Prediksi BEFORE pruning (fit)
pred_train = predict(fit, newdata = train)
pred_test = predict(fit, newdata = test)
res_before = data.frame(
dataset = c("train", "test"),
RMSE = c(rmse(train$medv, pred_train), rmse(test$medv, pred_test)),
MAE = c(mae(train$medv, pred_train), mae(test$medv, pred_test)),
R2 = c(r2(train$medv, pred_train), r2(test$medv, pred_test)),
Tahap = "before_pruning"
)
## 2) Prediksi AFTER pruning (fit_pruned) -> pastikan fit_pruned sudah ada!
pred_train_pr = predict(fit_pruned, newdata = train)
pred_test_pr = predict(fit_pruned, newdata = test)
res_pruned = data.frame(
dataset = c("train", "test"),
RMSE = c(rmse(train$medv, pred_train_pr), rmse(test$medv, pred_test_pr)),
MAE = c(mae(train$medv, pred_train_pr), mae(test$medv, pred_test_pr)),
R2 = c(r2(train$medv, pred_train_pr), r2(test$medv, pred_test_pr)),
Tahap = "after_pruning"
)
## 3) Gabungkan hasil
res_compare = dplyr::bind_rows(res_before, res_pruned) %>%
dplyr::select(Tahap, dataset, RMSE, MAE, R2)
res_compare
## Tahap dataset RMSE MAE R2
## 1 before_pruning train 3.868875 2.682173 0.8214813
## 2 before_pruning test 4.256762 3.268909 0.7911585
## 3 after_pruning train 4.489985 3.308679 0.7595615
## 4 after_pruning test 4.577769 3.328996 0.7584729