# NOTE: the workspace is intentionally not wiped with rm(list = ls()) -- that
# only removes objects, not loaded packages or options, and silently destroys
# the caller's session. Run this script in a fresh R session instead.
library(h2o)
## 
## ----------------------------------------------------------------------
## 
## Your next step is to start H2O:
##     > h2o.init()
## 
## For H2O package documentation, ask for help:
##     > ??h2o
## 
## After starting H2O, you can use the Web UI at http://localhost:54321
## For more information visit https://docs.h2o.ai
## 
## ----------------------------------------------------------------------
## 
## Attaching package: 'h2o'
## The following objects are masked from 'package:stats':
## 
##     cor, sd, var
## The following objects are masked from 'package:base':
## 
##     &&, %*%, %in%, ||, apply, as.factor, as.numeric, colnames,
##     colnames<-, ifelse, is.character, is.factor, is.numeric, log,
##     log10, log1p, log2, round, signif, trunc
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2
## ──
## ✔ ggplot2 3.4.2     ✔ purrr   1.0.1
## ✔ tibble  3.2.1     ✔ dplyr   1.1.2
## ✔ tidyr   1.3.0     ✔ stringr 1.5.0
## ✔ readr   2.1.4     ✔ forcats 0.5.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(splitstackshape)
library(pROC)
## Type 'citation("pROC")' for a citation.
## 
## Attaching package: 'pROC'
## 
## The following object is masked from 'package:h2o':
## 
##     var
## 
## The following objects are masked from 'package:stats':
## 
##     cov, smooth, var
# Start H2O ----
# Connect to (or launch) a local H2O cluster; its Web UI is then served at
# http://localhost:54321.
# h2o.removeAll()  # uncomment to remove all objects left on an existing cluster
h2o.init()
# Turn off H2O's console progress bars: when this script is knitted they are
# captured into the rendered file as bare `|====|` lines, which are not valid R.
h2o.no_progress()
# Input data ----
# Locate the red-wine CSV somewhere under `dir_path`. `dir()` takes a regular
# expression (not a glob), so anchor on a literal ".csv" suffix.
dir_path <- "/Users/xut2/Desktop/keras/"
dir_path_name <- dir(dir_path, pattern = "\\.csv$", full.names = TRUE,
                     recursive = TRUE)
# Match the file name exactly; a regex search on the full path could hit
# look-alike names or several copies in different sub-directories.
red_path <- dir_path_name[basename(dir_path_name) == "winequality-red.csv"]
if (length(red_path) != 1) {
  stop("Expected exactly one 'winequality-red.csv' under ", dir_path,
       ", found ", length(red_path), call. = FALSE)
}
# The file is semicolon-separated.
red <- read.csv(red_path, check.names = TRUE, header = TRUE,
                stringsAsFactors = FALSE, sep = ";")
colnames(red)
##  [1] "fixed.acidity"        "volatile.acidity"     "citric.acid"         
##  [4] "residual.sugar"       "chlorides"            "free.sulfur.dioxide" 
##  [7] "total.sulfur.dioxide" "density"              "pH"                  
## [10] "sulphates"            "alcohol"              "quality"
# Label ----
# Binary target: 1 when the quality score is 6 or higher, otherwise 0.
red <- red %>%
  mutate(good = if_else(quality >= 6, 1, 0))
dim(red)
table(red$good)
## [1] 1599   13
## 
##   0   1 
## 744 855
head(red)
# Rename the label to `num` (the name used downstream), drop the raw score so
# it cannot act as a feature, and store the label as a factor so that H2O
# treats the task as classification.
red <- red %>%
  rename(num = good) %>%
  select(-quality) %>%
  mutate(num = as.factor(num))
class(red)
## [1] "data.frame"
# Train/test split ----
# Remove exact duplicate rows *before* splitting. Deduplicating each split
# afterwards still lets identical wines land in both train and test, which
# leaks test rows into training.
red <- unique(red)
set.seed(2023)
# Draw a 30% test sample stratified on the label. stratified() returns a
# data.table whose own row names restart at "1", so row.names() of its result
# would just pick the first ~30% of rows. keep.rownames = TRUE carries the
# original row names along as column `rn`.
test_sample <- stratified(red, "num", 0.3, keep.rownames = TRUE)
data_test_list <- test_sample$rn
testdata <- red[row.names(red) %in% data_test_list, ]
traindata <- red[!row.names(red) %in% data_test_list, ]
str(traindata); str(testdata)
# Class balance should be similar in both splits.
table(traindata$num); table(testdata$num)
# Sanity checks -- both should be 0: no shared row names, and no identical
# rows (compared on every column) across the two splits.
sum(row.names(traindata) %in% row.names(testdata))
nrow(merge(traindata, testdata))
# Standardize and upload to H2O ----
head(traindata)
# Every column except the label is a numeric feature.
feature_cols <- setdiff(colnames(traindata), "num")
# Standardize with the *training* mean/sd only, and apply the same transform
# to the test set. Scaling the test set with its own statistics leaks test
# information and puts the two sets on different feature scales.
train_scaled <- scale(traindata[, feature_cols])
traindata[, feature_cols] <- train_scaled
testdata[, feature_cols] <- scale(
  testdata[, feature_cols],
  center = attr(train_scaled, "scaled:center"),
  scale = attr(train_scaled, "scaled:scale")
)
traindata <- as.h2o(traindata)
testdata <- as.h2o(testdata)
colnames(traindata)
##  [1] "fixed.acidity"        "volatile.acidity"     "citric.acid"         
##  [4] "residual.sugar"       "chlorides"            "free.sulfur.dioxide" 
##  [7] "total.sulfur.dioxide" "density"              "pH"                  
## [10] "sulphates"            "alcohol"              "num"
# Model ----
# Feed-forward network with three hidden layers of 100 units each. Predictors
# and response are chosen by name from the *training* frame, not by position
# from the test frame.
# R's set.seed() does not reach H2O's JVM-side random number generator, so the
# seed is passed to H2O directly. Fully deterministic runs also need
# reproducible = TRUE, which trains on a single thread (slow).
predictors <- setdiff(colnames(traindata), "num")
h2o_fit <- h2o.deeplearning(
  x = predictors,
  y = "num",
  training_frame = traindata,
  hidden = c(100, 100, 100),
  seed = 2023
)
# Test-set evaluation ----
# Class probabilities for the held-out test set; the `p1` column is the
# predicted probability of num == 1.
pred_prob_h2o <- h2o.predict(h2o_fit, testdata[, -ncol(testdata)])
# Pull observed labels and P(num == 1) into plain R vectors once.
actual_h2o <- as.numeric(as.matrix(testdata$num))
prob_h2o <- as.numeric(as.matrix(pred_prob_h2o$p1))
# Fix control/case levels and direction explicitly. Otherwise pROC guesses
# the direction from the data and can silently flip a worse-than-chance model
# into an AUC above 0.5.
roc_curve_h2o <- roc(actual_h2o, prob_h2o, levels = c(0, 1), direction = "<",
                     plot = TRUE, print.auc = TRUE)
roc_curve_h2o
# Test-set AUC. The global name `auc` is kept because later code may read it.
# The call still reaches pROC's auc(), since R skips non-function bindings
# when resolving a function call.
auc <- auc(roc_curve_h2o)
print(auc)
# "Best" probability cutoff on the ROC curve. Recent pROC versions return a
# data frame here, with one row per tied optimum, so take the first threshold
# rather than coercing the whole object.
cutoff_h2o <- coords(roc_curve_h2o, "best", ret = "threshold")
cutoff_h2o
best_cutoff_h2o <- cutoff_h2o[["threshold"]][1]
# NOTE(review): the cutoff is tuned on the same test set it is evaluated on,
# so the confusion matrix below is optimistic. Choose the cutoff on a separate
# validation split (or by cross-validation) for an unbiased estimate.
pred_class_h2o <- ifelse(prob_h2o >= best_cutoff_h2o, 1, 0)
head(pred_class_h2o)
# Confusion matrix: rows = observed class, columns = predicted class.
conf_mat <- table(actual = actual_h2o, predicted = pred_class_h2o)
print(conf_mat)
# Shutdown ----
# Stop the local H2O cluster without an interactive confirmation prompt.
h2o.shutdown(prompt = FALSE)
# ROC plot ----
library(ggplot2)
# pROC stores sensitivity (TPR) and specificity per threshold. The false
# positive rate is 1 - specificity. Plotting raw specificity under an "FPR"
# label with a reversed axis gives wrong tick labels.
# Sorting by (fpr, tpr) gives a monotone path from (0, 0) to (1, 1).
roc_df <- data.frame(
  fpr = 1 - roc_curve_h2o$specificities,
  tpr = roc_curve_h2o$sensitivities
) %>%
  arrange(fpr, tpr)
# Recompute the AUC from the ROC object instead of relying on a global `auc`.
auc_val <- as.numeric(pROC::auc(roc_curve_h2o))
# Dashed line = chance-level classifier (TPR = FPR).
ggplot(roc_df, aes(x = fpr, y = tpr)) +
  geom_path() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  ggtitle(paste0("ROC Curve (AUC = ", round(auc_val, 2), ")")) +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  coord_equal() +
  theme_bw()

#REF https://srdas.github.io/DLBook/DeepLearningWithR.html#using-mxnet
#https://cran.r-project.org/web/packages/deepnet/deepnet.pdf