Interpretable Machine Learning

Author

Alfa Nugraha

Introduction

Machine learning has great potential to improve the quality of a product, a process, or a research study. However, computer systems that adopt machine learning models are usually unable to explain the predictions they produce.

Several simple classical models that can be interpreted (interpretable models) are already very popular among machine learning practitioners, data scientists, statisticians, and researchers: linear regression, decision trees, and association rules. As computing technology has advanced, more complex models have emerged that can process large, unstructured data in a short time.

Random forests, ensemble models, and neural networks with their variants (deep learning) are complex models whose predictions or decisions often come without any further explanation or interpretation. These models are known as black box models.
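As a minimal illustration of the difference (a sketch using R's built-in cars dataset), the coefficients of a linear regression can be read off directly as estimated effects, which is exactly the kind of explanation a black box model does not offer:

    # An interpretable model: each coefficient has a direct reading
    fit <- lm(dist ~ speed, data = cars)
    coef(fit)  # intercept, plus the estimated change in dist per unit of speed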

Interpretability

There is no mathematical definition of interpretability. A (non-mathematical) definition by Miller (2017)1 is:

    Interpretability is the degree to which a human can understand the cause of a decision.

Another definition is:

    ... the degree to which a human can consistently predict the model's result2.

The higher the interpretability of a machine learning model, the easier it is for someone to understand why certain decisions or predictions were made. One model is more interpretable than another if its decisions are easier for a human to understand than the other model's decisions.

1 Miller, Tim. 2017. "Explanation in Artificial Intelligence: Insights from the Social Sciences." arXiv preprint arXiv:1706.07269.
2 Kim, Been, Rajiv Khanna, and Oluwasanmi O. Koyejo. 2016. "Examples Are Not Enough, Learn to Criticize! Criticism for Interpretability." Advances in Neural Information Processing Systems.

Case Study 1: Boston housing dataset

    data("Boston", package = "MASS")
    Boston
    crim zn indus chas nox rm age dis rad tax ptratio black lstat medv
    0.00632 18 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98 24.0
    0.02731 0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14 21.6
    0.02729 0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03 34.7
    0.03237 0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94 33.4
    0.06905 0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33 36.2
    0.02985 0 2.18 0 0.458 6.430 58.7 6.0622 3 222 18.7 394.12 5.21 28.7

    CART

# First we fit a machine learning model on the Boston housing data
library(rpart)   # CART trees
rf <- rpart(medv ~ ., data = Boston)
rpart.plot::rpart.plot(rf)

Figure 1: Decision tree 1
library(iml)   # Predictor, Shapley, FeatureImp
X <- Boston[-which(names(Boston) == "medv")]
mod <- Predictor$new(rf, data = X)

    SHAP

    # Then we explain the first instance of the dataset with the Shapley method:
    x.interest <- X[1, ]
    shapley <- Shapley$new(mod, x.interest = x.interest)
    shapley
    ## Interpretation method:  Shapley 
    ## Predicted value: 27.427273, Average prediction: 22.532806 (diff = 4.894466)
    ## 
    ## Analysed predictor: 
    ## Prediction task: unknown 
    ## 
    ## 
    ## Analysed data:
    ## Sampling from data.frame with 506 rows and 13 columns.
    ## 
    ## 
    ## Head of results:
    ##   feature       phi   phi.var feature.value
    ## 1    crim 0.6707019  3.040883  crim=0.00632
    ## 2      zn 0.0000000  0.000000         zn=18
    ## 3   indus 0.0000000  0.000000    indus=2.31
    ## 4    chas 0.0000000  0.000000        chas=0
    ## 5     nox 0.0000000  0.000000     nox=0.538
    ## 6      rm 0.5648295 46.013896      rm=6.575
# Look at the results in a table
library(knitr)   # kable() for table output
kable(shapley$results)
feature     phi         phi.var    feature.value
crim         0.6707019   3.040884  crim=0.00632
zn           0.0000000   0.000000  zn=18
indus        0.0000000   0.000000  indus=2.31
chas         0.0000000   0.000000  chas=0
nox          0.0000000   0.000000  nox=0.538
rm           0.5648295  46.013896  rm=6.575
age          0.0000000   0.000000  age=65.2
dis         -0.4806171   5.852113  dis=4.09
rad          0.0000000   0.000000  rad=1
tax          0.0000000   0.000000  tax=296
ptratio      0.0000000   0.000000  ptratio=15.3
black        0.0000000   0.000000  black=396.9
lstat        3.2864649  25.089114  lstat=4.98
    # Or as a plot
    plot(shapley)
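As a quick sanity check (a general property of Shapley values): the phi contributions sum to the difference between this instance's prediction and the average prediction, so the sum below should be close to the diff = 4.894466 printed above, up to sampling noise.

    # Shapley contributions sum to (predicted value - average prediction)
    sum(shapley$results$phi)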

    # Explain another instance
    shapley$explain(X[2, ])
    plot(shapley)

Case Study 2: Iris dataset

The following illustration applies the same approach to a multiclass problem.

    CART

    rf_iris <- rpart(Species ~ ., data = iris)
    rpart.plot::rpart.plot(rf_iris)

Figure 2: Decision tree 2
    X <- iris[-which(names(iris) == "Species")]
    mod <- Predictor$new(rf_iris, data = X, type = "prob")

    SHAP

    # Then we explain the first instance of the dataset with the Shapley() method:
    shapley <- Shapley$new(mod, x.interest = X[1, ])
    kable(shapley$results)
feature      class       phi         phi.var    feature.value
Sepal.Length setosa       0.0000000  0.0000000  Sepal.Length=5.1
Sepal.Width  setosa       0.0000000  0.0000000  Sepal.Width=3.5
Petal.Length setosa       0.6400000  0.2327273  Petal.Length=1.4
Petal.Width  setosa       0.0000000  0.0000000  Petal.Width=0.2
Sepal.Length versicolor   0.0000000  0.0000000  Sepal.Length=5.1
Sepal.Width  versicolor   0.0000000  0.0000000  Sepal.Width=3.5
Petal.Length versicolor  -0.5276006  0.2012435  Petal.Length=1.4
Petal.Width  versicolor   0.1505636  0.1117980  Petal.Width=0.2
Sepal.Length virginica    0.0000000  0.0000000  Sepal.Length=5.1
Sepal.Width  virginica    0.0000000  0.0000000  Sepal.Width=3.5
Petal.Length virginica   -0.1123994  0.0502612  Petal.Length=1.4
Petal.Width  virginica   -0.1505636  0.1117980  Petal.Width=0.2
    plot(shapley)

    # You can also focus on one class
    mod <- Predictor$new(rf_iris, data = X, type = "prob", class = "setosa")
    shapley <- Shapley$new(mod, x.interest = X[1, ])
    shapley$results
    ##        feature  phi   phi.var    feature.value
    ## 1 Sepal.Length 0.00 0.0000000 Sepal.Length=5.1
    ## 2  Sepal.Width 0.00 0.0000000  Sepal.Width=3.5
    ## 3 Petal.Length 0.78 0.1733333 Petal.Length=1.4
    ## 4  Petal.Width 0.00 0.0000000  Petal.Width=0.2
    plot(shapley)

Case Study 3: Bank Loans

A banking company studied 75 loan schemes that have been rated by its customers, stored in the following loan.csv file:

# Read the data
loan <- read.csv("loan.csv", header = TRUE)
head(loan)
besar.pinjaman lama.pembayaran bunga pembayaran.per.bulan banyak.cash.back   rating
            70               4     1                  130             10.0 68.40297
           120               3     5                   15              2.0 33.98368
            70               4     1                  260              9.0 59.42551
            50               4     0                  140             14.0 93.70491
           110               2     2                  180              1.5 29.50954
           110               2     0                  125              1.0 33.17409

The variables used are:

• besar.pinjaman: loan amount (in millions of rupiah)
• lama.pembayaran: repayment period (in years)
• bunga: additional interest rate (in %)
• pembayaran.per.bulan: monthly payment (in units of 10,000)
• banyak.cash.back: number of cash backs applied to the scheme

The goal of the study is to predict a loan scheme's rating from these variables. Which variable is the most important?

Data preparation

# Random sampling
samplesize <- 0.60 * nrow(loan)

set.seed(80)
index <- sample(seq_len(nrow(loan)), size = samplesize)
# Create training and test set
datatrain <- loan[index, ]
datatest <- loan[-index, ]

Scaling

The first step is to scale the dataset. This matters because, otherwise, a variable may dominate the prediction simply because of its scale. Here we use min-max scaling, which maps each variable to [0, 1] via (x - min) / (max - min). Unscaled variables often produce results that are hard to interpret.

## Scale the data for the neural network (min-max to [0, 1])
max <- apply(loan, 2, max)
min <- apply(loan, 2, min)
scaled <- as.data.frame(scale(loan, center = min, scale = max - min))
kable(head(scaled))
besar.pinjaman lama.pembayaran bunga pembayaran.per.bulan banyak.cash.back    rating
     0.1818182             0.6   0.2             0.406250        0.7142857 0.6655928
     0.6363636             0.4   1.0             0.046875        0.1428571 0.2106846
     0.1818182             0.6   0.2             0.812500        0.6428571 0.5469406
     0.0000000             0.6   0.0             0.437500        1.0000000 1.0000000
     0.5454545             0.2   0.4             0.562500        0.1071429 0.1515514
     0.5454545             0.2   0.0             0.390625        0.0714286 0.1999845

ANN with neuralnet

Besides the keras and tensorflow packages, the R community has also developed artificial neural network (ANN) models in the neuralnet package.

Data partitioning

# creating training and test set
trainNN <- scaled[index, ]
testNN <- scaled[-index, ]

Building the model

# Fit the neural network: one hidden layer with 3 neurons
library(neuralnet)
set.seed(2)
NN <- neuralnet(
  rating ~ besar.pinjaman + lama.pembayaran + bunga + pembayaran.per.bulan + banyak.cash.back,
  data = trainNN,
  hidden = 3,
  linear.output = TRUE   # regression output, no activation on the output node
)

ANN visualization

    # plot neural network
    plot(NN)

Prediction

## Prediction using the neural network
predict_testNN <- compute(NN, testNN[, c(1:5)])
# Un-scale the predictions back to the original rating range
predict_testNN <- (predict_testNN$net.result * (max(loan$rating) - min(loan$rating))) + min(loan$rating)
plot(datatest$rating, predict_testNN, col = 'blue', pch = 16,
     ylab = "predicted rating NN", xlab = "real rating")

abline(0, 1)

    # Calculate Root Mean Square Error (RMSE)
    RMSE.NN <- (sum((datatest$rating - predict_testNN)^2) / nrow(datatest)) ^ 0.5
    cat("Metrik RMSE: ", RMSE.NN)
    ## Metrik RMSE:  6.945017

    Feature importance

X <- loan[which(names(loan) != "rating")]
# Note: NN was trained on min-max scaled inputs, but X here holds the raw
# features; for fully meaningful importances, consider passing the scaled data.
predictor <- Predictor$new(NN, data = X, y = loan$rating)

imp <- FeatureImp$new(predictor, loss = "mae")
    plot(imp)

    imp$results
    ##                feature importance.05 importance importance.95 permutation.error
    ## 1       besar.pinjaman     0.9993735  1.0006164     1.0011689          42.45750
    ## 2     banyak.cash.back     0.9996252  1.0001133     1.0006106          42.43615
    ## 3      lama.pembayaran     0.9992501  0.9997796     1.0004272          42.42199
    ## 4                bunga     0.9994564  0.9996783     0.9999468          42.41770
    ## 5 pembayaran.per.bulan     0.9976793  0.9982261     1.0000772          42.35607
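The importance column above is a ratio: the model's error after permuting one feature, divided by its original error, so values near 1 mean the feature barely affects the model. A minimal manual sketch of the same idea, assuming the NN model and the scaled training set trainNN from above:

    # Manual permutation importance for one feature (ratio > 1 = feature matters)
    pred_fun <- function(d) as.numeric(compute(NN, d[, 1:5])$net.result)
    err_orig <- mean(abs(trainNN$rating - pred_fun(trainNN)))
    shuffled <- trainNN
    shuffled$bunga <- sample(shuffled$bunga)   # break the feature-target link
    err_perm <- mean(abs(trainNN$rating - pred_fun(shuffled)))
    err_perm / err_orig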

    SHAP

    shapley <- Shapley$new(predictor, x.interest = X[1, ])
    shapley$plot()

Case Study 4: Stroke Prediction

    According to the World Health Organization (WHO) stroke is the 2nd leading cause of death globally, responsible for approximately 11% of total deaths.3

This dataset is used to predict whether a patient is likely to have a stroke based on input parameters such as gender, age, various diseases, and smoking status. Each row in the data provides relevant information about the patient.4

    Attribute Information:

    1. id: unique identifier
    2. gender: “Male”, “Female” or “Other”
    3. age: age of the patient
    4. hypertension: 0 if the patient doesn’t have hypertension, 1 if the patient has hypertension
    5. heart_disease: 0 if the patient doesn’t have any heart diseases, 1 if the patient has a heart disease
    6. ever_married: “No” or “Yes”
7. work_type: “children”, “Govt_job”, “Never_worked”, “Private” or “Self-employed”
    8. Residence_type: “Rural” or “Urban”
    9. avg_glucose_level: average glucose level in blood
    10. bmi: body mass index
11. smoking_status: “formerly smoked”, “never smoked”, “smokes” or “Unknown”
    12. stroke: 1 if the patient had a stroke or 0 if not

    Note: Unknown in smoking_status means that the information is unavailable for this patient

library(tidyverse)   # read_csv, mutate, recode_factor, pipes
stroke <- read_csv("healthcare-dataset-stroke-data.csv") %>%
          mutate(hypertension = recode_factor(hypertension, `1` = "Yes", `0` = "No"),
                 heart_disease = recode_factor(heart_disease, `1` = "Yes", `0` = "No"),
                 stroke = factor(stroke),
                 bmi = as.numeric(as.character(bmi))) %>%
          mutate_if(is.character, as.factor)
    ## Rows: 5110 Columns: 12
    ## ── Column specification ────────────────────────────────────────────────────────
    ## Delimiter: ","
    ## chr (6): gender, ever_married, work_type, Residence_type, bmi, smoking_status
    ## dbl (6): id, age, hypertension, heart_disease, avg_glucose_level, stroke
    ## 
    ## ℹ Use `spec()` to retrieve the full column specification for this data.
    ## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
    ## Warning in eval(cols[[col]], .data, parent.frame()): NAs introduced by coercion
    stroke$id <- NULL

    Note: We will use the DALEX package for Explanatory Model Analysis (EMA). To use some of the functions in this package, categorical predictor variables need to be converted to factors.

library(skimr)   # skim() summary statistics
skim(stroke)
Data summary
Name                     stroke
Number of rows           5110
Number of columns        11
Column type frequency:
  factor                 8
  numeric                3
Group variables          None

Variable type: factor

skim_variable   n_missing complete_rate ordered n_unique top_counts
gender                  0             1 FALSE          3 Fem: 2994, Mal: 2115, Oth: 1
hypertension            0             1 FALSE          2 No: 4612, Yes: 498
heart_disease           0             1 FALSE          2 No: 4834, Yes: 276
ever_married            0             1 FALSE          2 Yes: 3353, No: 1757
work_type               0             1 FALSE          5 Pri: 2925, Sel: 819, chi: 687, Gov: 657
Residence_type          0             1 FALSE          2 Urb: 2596, Rur: 2514
smoking_status          0             1 FALSE          4 nev: 1892, Unk: 1544, for: 885, smo: 789
stroke                  0             1 FALSE          2 0: 4861, 1: 249

Variable type: numeric

skim_variable     n_missing complete_rate   mean    sd    p0   p25   p50    p75   p100 hist
age                       0          1.00  43.23 22.61  0.08 25.00 45.00  61.00  82.00 ▅▆▇▇▆
avg_glucose_level         0          1.00 106.15 45.28 55.12 77.24 91.88 114.09 271.74 ▇▃▁▁▁
bmi                     201          0.96  28.89  7.85 10.30 23.50 28.10  33.10  97.60 ▇▇▁▁▁
kable(head(stroke))
gender age hypertension heart_disease ever_married work_type     Residence_type avg_glucose_level  bmi smoking_status  stroke
Male    67 No           Yes           Yes          Private       Urban                     228.69 36.6 formerly smoked      1
Female  61 No           No            Yes          Self-employed Rural                     202.21   NA never smoked         1
Male    80 No           Yes           Yes          Private       Rural                     105.92 32.5 never smoked         1
Female  49 No           No            Yes          Private       Urban                     171.23 34.4 smokes               1
Female  79 Yes          No            Yes          Self-employed Rural                     174.12 24.0 never smoked         1
Male    81 No           No            Yes          Private       Urban                     186.21 29.0 formerly smoked      1

Missing data

smoking_status is Unknown for about 30% of patients (kept as its own factor level), and bmi is missing for a few rows (under 5%). We omit the rows with missing values; the observations with missing values were also omitted in the previously mentioned article published on arxiv.org.

    stroke <- na.omit(stroke)

Imbalanced data

It is worth checking how many patients are labeled as stroke (1) versus not (0), since this variable is the target we will use in modeling.

    table(stroke$stroke)
    ## 
    ##    0    1 
    ## 4700  209
    prop.table(table(stroke$stroke))
    ## 
    ##          0          1 
    ## 0.95742514 0.04257486

Based on the output above, 4,700 patients (about 96% of the total) did not have a stroke.

    barplot(prop.table(table(stroke$stroke, stroke$gender), margin=2), col=c(2, 4))

The dataset is highly unbalanced: only about 4% of the people in it suffered a stroke. This poses a difficult problem when training a decision tree (in fact, any machine-learning model).

    tree <- rpart(stroke ~ ., data = stroke, method = "class")
    tree
    ## n= 4909 
    ## 
    ## node), split, n, loss, yval, (yprob)
    ##       * denotes terminal node
    ## 
    ## 1) root 4909 209 0 (0.95742514 0.04257486) *

    The tree simply predicts stroke = “No” for every observation in the dataset.
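This happens because, with a roughly 96/4 split, always predicting the majority class already minimizes the misclassification rate. Besides resampling, one alternative is cost-sensitive learning, sketched below with an rpart loss matrix (the 20x penalty is an arbitrary illustration, not a tuned value):

    # Sketch: cost-sensitive CART; loss matrix rows = true class (0, 1),
    # columns = predicted class, so a missed stroke costs 20x a false alarm
    tree_loss <- rpart(stroke ~ ., data = stroke, method = "class",
                       parms = list(loss = matrix(c(0, 20, 1, 0), nrow = 2)))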

Downsample data

We apply random downsampling to reduce the adverse effect of the unbalanced dataset.

    set.seed(123)
    minority_obs <- stroke %>% filter(stroke == 1)
    majority_obs <- stroke %>% filter(stroke == 0) %>% sample_n(nrow(minority_obs))
    balanced_data <- bind_rows(minority_obs, majority_obs)
    prop.table(table(balanced_data$stroke))
    ## 
    ##   0   1 
    ## 0.5 0.5
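For reference, caret also ships a helper that performs the same balancing in one call (a sketch; downSample keeps all minority rows and samples the majority down to match, and column 11 of stroke is the target here):

    # Equivalent balancing with caret's helper
    balanced_alt <- caret::downSample(x = stroke[, -11], y = stroke$stroke,
                                      yname = "stroke")
    prop.table(table(balanced_alt$stroke))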

Data partitioning

    set.seed(123)
    training_samples <- as.vector(caret::createDataPartition(balanced_data$stroke, p = 0.7, list = FALSE))
    train_data <- balanced_data[ training_samples, ]
    test_data  <- balanced_data[-training_samples, ]

Building the models

    CART

    cart <- rpart(stroke ~ ., data = train_data, method = "class")
    rpart.plot(cart, extra = 4)

Model evaluation

    Evaluate model performance on the test dataset.

caret::confusionMatrix(predict(cart, test_data, type = "class"),
                       as.factor(test_data$stroke),
                       positive = "1", mode = "prec_recall")
    ## Confusion Matrix and Statistics
    ## 
    ##           Reference
    ## Prediction  0  1
    ##          0 39  9
    ##          1 23 53
    ##                                           
    ##                Accuracy : 0.7419          
    ##                  95% CI : (0.6557, 0.8163)
    ##     No Information Rate : 0.5             
    ##     P-Value [Acc > NIR] : 3.264e-08       
    ##                                           
    ##                   Kappa : 0.4839          
    ##                                           
    ##  Mcnemar's Test P-Value : 0.02156         
    ##                                           
    ##               Precision : 0.6974          
    ##                  Recall : 0.8548          
    ##                      F1 : 0.7681          
    ##              Prevalence : 0.5000          
    ##          Detection Rate : 0.4274          
    ##    Detection Prevalence : 0.6129          
    ##       Balanced Accuracy : 0.7419          
    ##                                           
    ##        'Positive' Class : 1               
    ## 

    Random Forest

library(caret)   # trainControl, train, resamples
trControl <- trainControl(method = "cv", number = 10, search = "grid")

set.seed(1234)
rf_mtry <- caret::train(as.factor(stroke) ~ .,
    data = train_data,
    method = "rf",
    metric = "Accuracy",
    trControl = trControl)

best_mtry <- rf_mtry$bestTune$mtry
best_mtry
    ## [1] 9
    max(rf_mtry$results$Accuracy)
    ## [1] 0.7850575
    store_maxnode <- list()
    tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in 5:15) {
        set.seed(1234)
        rf_maxnode <- caret::train(as.factor(stroke) ~ .,
            data = train_data,
            method = "rf",
            metric = "Accuracy",
            tuneGrid = tuneGrid,
            trControl = trControl,
            importance = TRUE,
            nodesize = 14,
            maxnodes = maxnodes,
            ntree = 300)
        current_iteration <- toString(maxnodes)
        store_maxnode[[current_iteration]] <- rf_maxnode
    }
    results_mtry <- resamples(store_maxnode)
    summary(results_mtry)
    ## 
    ## Call:
    ## summary.resamples(object = results_mtry)
    ## 
    ## Models: 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 
    ## Number of resamples: 10 
    ## 
    ## Accuracy 
    ##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
    ## 5  0.6551724 0.7347701 0.7931034 0.7783908 0.8318966 0.8333333    0
    ## 6  0.6896552 0.7606322 0.8103448 0.7887356 0.8318966 0.8333333    0
    ## 7  0.7241379 0.7347701 0.8103448 0.7920690 0.8318966 0.8666667    0
    ## 8  0.7241379 0.7347701 0.8103448 0.7920690 0.8318966 0.8666667    0
    ## 9  0.7241379 0.7347701 0.8103448 0.7954023 0.8318966 0.8666667    0
    ## 10 0.5862069 0.7413793 0.8000000 0.7781609 0.8275862 0.8666667    0
    ## 11 0.6206897 0.7672414 0.8000000 0.7816092 0.8275862 0.8666667    0
    ## 12 0.5862069 0.7327586 0.7833333 0.7645977 0.8275862 0.8666667    0
    ## 13 0.5862069 0.7586207 0.7666667 0.7647126 0.8189655 0.8666667    0
    ## 14 0.5517241 0.7586207 0.7798851 0.7611494 0.8206897 0.8666667    0
    ## 15 0.5517241 0.7586207 0.7798851 0.7577011 0.8206897 0.8666667    0
    ## 
    ## Kappa 
    ##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
    ## 5  0.3028846 0.4767913 0.5827314 0.5569665 0.6643026 0.6666667    0
    ## 6  0.3741007 0.5267760 0.6180051 0.5778075 0.6643026 0.6666667    0
    ## 7  0.4449761 0.4748826 0.6180051 0.5846426 0.6643026 0.7333333    0
    ## 8  0.4449761 0.4748826 0.6180051 0.5846426 0.6643026 0.7333333    0
    ## 9  0.4449761 0.4748826 0.6180051 0.5913093 0.6643026 0.7333333    0
    ## 10 0.1753555 0.4855557 0.6000000 0.5568120 0.6559773 0.7333333    0
    ## 11 0.2458629 0.5360039 0.6000000 0.5635369 0.6559773 0.7333333    0
    ## 12 0.1753555 0.4685009 0.5666667 0.5296044 0.6559773 0.7333333    0
    ## 13 0.1753555 0.5166585 0.5333333 0.5297633 0.6361281 0.7333333    0
    ## 14 0.1045131 0.5166585 0.5605055 0.5226206 0.6392086 0.7333333    0
    ## 15 0.1045131 0.5166585 0.5605055 0.5159273 0.6392086 0.7333333    0
    store_maxtrees <- list()
    for (ntree in c(250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000)) {
        set.seed(5678)
        rf_maxtrees <- caret::train(as.factor(stroke) ~ .,
            data = train_data,
            method = "rf",
            metric = "Accuracy",
            tuneGrid = tuneGrid,
            trControl = trControl,
            importance = TRUE,
            nodesize = 14,
            maxnodes = 14,
            ntree = ntree)
        key <- toString(ntree)
        store_maxtrees[[key]] <- rf_maxtrees
    }
    results_tree <- resamples(store_maxtrees)
    summary(results_tree)
    ## 
    ## Call:
    ## summary.resamples(object = results_tree)
    ## 
    ## Models: 250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000 
    ## Number of resamples: 10 
    ## 
    ## Accuracy 
    ##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
    ## 250  0.6785714 0.7396552 0.7965517 0.7953284 0.8534483 0.9000000    0
    ## 300  0.6785714 0.7396552 0.7931034 0.7884319 0.8206897 0.9000000    0
    ## 350  0.6785714 0.7396552 0.7931034 0.7952135 0.8465517 0.9333333    0
    ## 400  0.6785714 0.7396552 0.7931034 0.7952135 0.8465517 0.9333333    0
    ## 450  0.6428571 0.7396552 0.7931034 0.7848604 0.8206897 0.9000000    0
    ## 500  0.6428571 0.7396552 0.7931034 0.7848604 0.8206897 0.9000000    0
    ## 550  0.6428571 0.7396552 0.7931034 0.7881938 0.8206897 0.9333333    0
    ## 600  0.6428571 0.7396552 0.7931034 0.7881938 0.8206897 0.9333333    0
    ## 800  0.6785714 0.7396552 0.7931034 0.7850985 0.8206897 0.9000000    0
    ## 1000 0.6785714 0.7396552 0.7931034 0.7884319 0.8206897 0.9000000    0
    ## 2000 0.6785714 0.7396552 0.7931034 0.7884319 0.8206897 0.9333333    0
    ## 
    ## Kappa 
    ##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
    ## 250  0.3571429 0.4811475 0.5928571 0.5911511 0.7069084 0.8000000    0
    ## 300  0.3571429 0.4811475 0.5837225 0.5773055 0.6429078 0.8000000    0
    ## 350  0.3571429 0.4811475 0.5837225 0.5907630 0.6938389 0.8666667    0
    ## 400  0.3571429 0.4811475 0.5837225 0.5907630 0.6938389 0.8666667    0
    ## 450  0.2857143 0.4811475 0.5837225 0.5701627 0.6429078 0.8000000    0
    ## 500  0.2857143 0.4811475 0.5837225 0.5701627 0.6429078 0.8000000    0
    ## 550  0.2857143 0.4811475 0.5837225 0.5768293 0.6429078 0.8666667    0
    ## 600  0.2857143 0.4811475 0.5837225 0.5768293 0.6429078 0.8666667    0
    ## 800  0.3571429 0.4811475 0.5837225 0.5706388 0.6429078 0.8000000    0
    ## 1000 0.3571429 0.4811475 0.5837225 0.5773055 0.6429078 0.8000000    0
    ## 2000 0.3571429 0.4811475 0.5837225 0.5773055 0.6429078 0.8666667    0

Calibration plot

set.seed(1234)
# Note: the original call passed "mrty = 2", a typo that was silently ignored;
# caret expects mtry via tuneGrid.
model_forest <- caret::train(as.factor(stroke) ~ ., data = train_data,
                             method = "rf", metric = "Accuracy",
                             tuneGrid = data.frame(mtry = 2),
                             ntree = 300, maxnodes = 14)
pred <- predict(model_forest, test_data, type = "prob")
prob <- pred[, "1"]
calCurve <- caret::calibration(as.factor(test_data$stroke) ~ prob,
                               data = test_data, class = '1')
calCurve
    ## 
    ## Call:
    ## calibration.formula(x = as.factor(test_data$stroke) ~ prob, data =
    ##  test_data, class = "1")
    ## 
    ## Models: prob 
    ## Event:  1 
    ## Cuts: 11
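The printed object only summarizes the cut points; to actually draw the curve, the calibration object can be passed to lattice's xyplot method (a sketch using caret's standard plotting support):

    # Draw the calibration curve
    library(lattice)
    xyplot(calCurve)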

    Explanatory Model Analysis using DALEX

library(DALEX)   # explain(), update_data(), model_performance(), ...
model_tree <- rpart(stroke ~ ., data = train_data, method = "class")
exp_tree <- DALEX::explain(model_tree, data = train_data[, -11],
                           y = as.numeric(train_data$stroke) - 1,
                           label = "Decision Tree", type = "classification")
    ## Preparation of a new explainer is initiated
    ##   -> model label       :  Decision Tree 
    ##   -> data              :  294  rows  10  cols 
    ##   -> data              :  tibble converted into a data.frame 
    ##   -> target variable   :  294  values 
    ##   -> predict function  :  yhat.rpart  will be used (  default  )
    ##   -> predicted values  :  No value for predict function target column. (  default  )
    ##   -> model_info        :  package rpart , ver. 4.1.16 , task classification (  default  ) 
    ##   -> model_info        :  type set to  classification 
    ##   -> predicted values  :  numerical, min =  0.05208333 , mean =  0.5 , max =  0.9  
    ##   -> residual function :  difference between y and yhat (  default  )
    ##   -> residuals         :  numerical, min =  -0.9 , mean =  7.983914e-18 , max =  0.9479167  
    ##   A new explainer has been created!
    exp_tree_updated <- update_data(exp_tree, data = test_data[,-11], y = as.numeric(test_data$stroke)-1)
    ##   -> data              :  124  rows  10  cols 
    ##   -> target variable   :  124  values 
    ##   An explainer has been updated! 
model_forest <- caret::train(as.factor(stroke) ~ ., data = train_data,
                             method = "rf", metric = "Accuracy",
                             tuneGrid = data.frame(mtry = 2),
                             ntree = 300, maxnodes = 14)
exp_forest <- DALEX::explain(model_forest, data = train_data[, -11],
                             y = as.numeric(train_data$stroke) - 1,
                             label = "Random forest", type = "classification")
    ## Preparation of a new explainer is initiated
    ##   -> model label       :  Random forest 
    ##   -> data              :  294  rows  10  cols 
    ##   -> data              :  tibble converted into a data.frame 
    ##   -> target variable   :  294  values 
    ##   -> predict function  :  yhat.train  will be used (  default  )
    ##   -> predicted values  :  No value for predict function target column. (  default  )
    ##   -> model_info        :  package caret , ver. 6.0.93 , task classification (  default  ) 
    ##   -> model_info        :  type set to  classification 
    ##   -> predicted values  :  numerical, min =  0 , mean =  0.5494898 , max =  1  
    ##   -> residual function :  difference between y and yhat (  default  )
    ##   -> residuals         :  numerical, min =  -0.9133333 , mean =  -0.0494898 , max =  0.6533333  
    ##   A new explainer has been created!
    exp_forest_updated <- update_data(exp_forest, data = test_data[,-11], y = as.numeric(test_data$stroke)-1)
    ##   -> data              :  124  rows  10  cols 
    ##   -> target variable   :  124  values 
    ##   An explainer has been updated! 

Model performance

    CART

    perf_tree <- model_performance(exp_tree_updated)
    perf_tree
    ## Measures for:  classification
    ## recall     : 0.8548387 
    ## precision  : 0.6973684 
    ## f1         : 0.7681159 
    ## accuracy   : 0.7419355 
    ## auc        : 0.7814776
    ## 
    ## Residuals:
    ##          0%         10%         20%         30%         40%         50% 
    ## -0.90000000 -0.80900749 -0.42307692 -0.05208333 -0.05208333  0.02395833 
    ##         60%         70%         80%         90%        100% 
    ##  0.16831683  0.16831683  0.16831683  0.24390244  0.94791667

    Random Forest

    perf_forest <- model_performance(exp_forest_updated)
    perf_forest
    ## Measures for:  classification
    ## recall     : 0.9032258 
    ## precision  : 0.7179487 
    ## f1         : 0.8 
    ## accuracy   : 0.7741935 
    ## auc        : 0.8433923
    ## 
    ## Residuals:
    ##           0%          10%          20%          30%          40%          50% 
    ## -0.993333333 -0.736000000 -0.455333333 -0.041000000 -0.006000000  0.001666667 
    ##          60%          70%          80%          90%         100% 
    ##  0.039333333  0.086666667  0.213333333  0.359666667  0.970000000

Plots

    plot(perf_tree, perf_forest, geom = "roc")

    plot(perf_tree, perf_forest, geom = "boxplot")

    Feature attribution - prediction level

    henry <- test_data[1,]
    predict(exp_tree, henry)
    ## [1] 0.7560976
    predict(exp_forest, henry)
    ## [1] 0.9866667
    henry$stroke
    ## [1] 1
    ## Levels: 0 1
    sh_forest <- predict_parts(exp_forest, henry, type = "shap", B = 1)
    plot(sh_forest, show_boxplots = FALSE) + 
       ggtitle("Shapley values for Henry","")

    bd_forest <- predict_parts(exp_forest, henry, type = "break_down_interactions")
    bd_forest
    ##                                           contribution
    ## Random forest: intercept                         0.549
    ## Random forest: age = 67                          0.224
    ## Random forest: avg_glucose_level = 228.69        0.051
    ## Random forest: bmi = 36.6                        0.115
    ## Random forest: heart_disease = 1                 0.015
    ## Random forest: ever_married = 2                  0.009
    ## Random forest: hypertension = 2                 -0.001
    ## Random forest: smoking_status = 1                0.023
    ## Random forest: Residence_type = 2                0.000
    ## Random forest: work_type = 4                     0.001
    ## Random forest: gender = 2                        0.000
    ## Random forest: prediction                        0.987
plot(bd_forest, show_boxplots = FALSE) +
  ggtitle("Break down values for Henry", "")

    Feature importance - model level

    Visualize feature importance for the decision tree and random forest model.

    mp_tree <- model_parts(exp_tree, type = "difference")
    mp_forest<- model_parts(exp_forest, type = "difference")
    plot(mp_tree, mp_forest, show_boxplots = FALSE)

Ceteris-paribus profiles (What if?)

    Visualize how the model response would change for Henry if one of the coordinates in his observation was changed while leaving all other coordinates unchanged.

    cp_forest <- predict_profile(exp_forest, henry)
    cp_forest
    ## Top profiles    : 
    ##   gender     age hypertension heart_disease ever_married work_type
    ## 1 Female 67.0000           No           Yes          Yes   Private
    ## 2   Male 67.0000           No           Yes          Yes   Private
    ## 3   Male  0.4800           No           Yes          Yes   Private
    ## 4   Male  1.2952           No           Yes          Yes   Private
    ## 5   Male  2.1104           No           Yes          Yes   Private
    ## 6   Male  2.9256           No           Yes          Yes   Private
    ##   Residence_type avg_glucose_level  bmi  smoking_status    _yhat_ _vname_ _ids_
    ## 1          Urban            228.69 36.6 formerly smoked 0.9866667  gender     1
    ## 2          Urban            228.69 36.6 formerly smoked 0.9866667  gender     1
    ## 3          Urban            228.69 36.6 formerly smoked 0.1133333     age     1
    ## 4          Urban            228.69 36.6 formerly smoked 0.1133333     age     1
    ## 5          Urban            228.69 36.6 formerly smoked 0.1133333     age     1
    ## 6          Urban            228.69 36.6 formerly smoked 0.1133333     age     1
    ##         _label_
    ## 1 Random forest
    ## 2 Random forest
    ## 3 Random forest
    ## 4 Random forest
    ## 5 Random forest
    ## 6 Random forest
    ## 
    ## 
    ## Top observations:
    ##   gender age hypertension heart_disease ever_married work_type Residence_type
    ## 1   Male  67           No           Yes          Yes   Private          Urban
    ##   avg_glucose_level  bmi  smoking_status    _yhat_       _label_ _ids_
    ## 1            228.69 36.6 formerly smoked 0.9866667 Random forest     1
    plot(cp_forest, variables = c("age", "avg_glucose_level"))

    Partial dependence profiles

Create a partial dependence profile for age for the random forest model. Partial dependence profiles are averages of the ceteris-paribus profiles over all observations.

    mp_forest <- model_profile(exp_forest)
    plot(mp_forest, variables = "age")
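To make the averaging concrete, here is a minimal manual sketch of the same profile (assuming exp_forest and train_data from above, and DALEX's predict method for explainers): fix age at each grid value for every observation, predict, and average.

    # Manual partial dependence for age: average prediction across the data
    age_grid <- seq(min(train_data$age), max(train_data$age), length.out = 20)
    pd_age <- sapply(age_grid, function(a) {
      tmp <- train_data
      tmp$age <- a                # set everyone's age to the grid value
      mean(predict(exp_forest, tmp))
    })
    plot(age_grid, pd_age, type = "l", xlab = "age", ylab = "average prediction")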