# Packages -----------------------------------------------------------------
# Run this script in a fresh R session (restart R) when a clean state is
# needed. Clearing the workspace with rm(list = ls()) only deletes global
# objects; it leaves loaded packages and options in place.
#install.packages("keras")
library(neuralnet)
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2 ──
## ✔ ggplot2 3.4.2     ✔ purrr   1.0.1
## ✔ tibble  3.2.1     ✔ dplyr   1.1.2
## ✔ tidyr   1.3.0     ✔ stringr 1.5.0
## ✔ readr   2.1.4     ✔ forcats 0.5.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::compute() masks neuralnet::compute()
## ✖ dplyr::filter()  masks stats::filter()
## ✖ dplyr::lag()     masks stats::lag()
# Note: dplyr masks neuralnet::compute(); call neuralnet::compute() explicitly
# if it is ever needed (this script uses predict() on the fitted network).
library(splitstackshape)
library(pROC)
## Type 'citation("pROC")' for a citation.
## 
## Attaching package: 'pROC'
## 
## The following objects are masked from 'package:stats':
## 
##     cov, smooth, var
#install.packages("caret")
#?install_keras() #Install TensorFlow and Keras, including all Python dependencies
#install_keras(tensorflow = "gpu")
# Input data ---------------------------------------------------------------
# Folder that contains the wine-quality CSV files. Set the WINE_DATA_DIR
# environment variable to use another location without editing the script.
dir_path <- Sys.getenv("WINE_DATA_DIR", unset = "/Users/xut2/Desktop/keras/")
# `pattern` is a regular expression, not a glob: "\\.csv$" matches file names
# ending in ".csv". The previous "*.csv" was unanchored, and its "." matched
# any character.
dir_path_name <- dir(dir_path, pattern = "\\.csv$", full.names = TRUE,
                     recursive = TRUE)
#dir_path_name
# Match the file name exactly. grep() treated "." as a wildcard and could
# return several paths, which read.csv() cannot read.
red_file <- dir_path_name[basename(dir_path_name) == "winequality-red.csv"]
if (length(red_file) != 1L) {
  stop("Expected exactly one 'winequality-red.csv' under '", dir_path,
       "', found ", length(red_file), ".", call. = FALSE)
}
red <- read.csv(red_file, check.names = TRUE, header = TRUE,
                stringsAsFactors = FALSE, sep = ";")
#dim(red) #[1] 1599   12
#head(red)
colnames(red)
##  [1] "fixed.acidity"        "volatile.acidity"     "citric.acid"         
##  [4] "residual.sugar"       "chlorides"            "free.sulfur.dioxide" 
##  [7] "total.sulfur.dioxide" "density"              "pH"                  
## [10] "sulphates"            "alcohol"              "quality"
# Outcome ------------------------------------------------------------------
# Binary label: 1 when the quality score is 6 or higher, 0 otherwise.
# neuralnet needs the outcome as a number, not a factor.
red <- red %>% mutate(good = as.numeric(quality >= 6))
dim(red); table(red$good) # 0   1 744 855
## [1] 1599   13
## 
##   0   1 
## 744 855
head(red)
##   fixed.acidity volatile.acidity citric.acid residual.sugar chlorides
## 1           7.4             0.70        0.00            1.9     0.076
## 2           7.8             0.88        0.00            2.6     0.098
## 3           7.8             0.76        0.04            2.3     0.092
## 4          11.2             0.28        0.56            1.9     0.075
## 5           7.4             0.70        0.00            1.9     0.076
## 6           7.4             0.66        0.00            1.8     0.075
##   free.sulfur.dioxide total.sulfur.dioxide density   pH sulphates alcohol
## 1                  11                   34  0.9978 3.51      0.56     9.4
## 2                  25                   67  0.9968 3.20      0.68     9.8
## 3                  15                   54  0.9970 3.26      0.65     9.8
## 4                  17                   60  0.9980 3.16      0.58     9.8
## 5                  11                   34  0.9978 3.51      0.56     9.4
## 6                  13                   40  0.9978 3.51      0.56     9.4
##   quality good
## 1       5    0
## 2       5    0
## 3       5    0
## 4       6    1
## 5       5    0
## 6       5    0
# The label is called `num` from here on. The raw quality score is dropped
# because the label is derived directly from it. `num` ends up as the last
# column.
red <- red %>%
  rename(num = good) %>%
  select(-quality)
# Train/test split ---------------------------------------------------------
set.seed(2023)
# Remove exact duplicate rows BEFORE splitting. The file contains repeated
# wines (e.g. rows 1 and 5). De-duplicating each split separately let the
# same wine appear in both the training and the test set.
red <- unique(red)
# Stratified 30% test sample, balanced on the outcome `num`.
# stratified() returns a data.table whose row names are renumbered 1..k, so
# taking row.names() of its result selected simply the first k rows of `red`
# (a non-random split). keep.rownames = TRUE carries the original row names
# through in the `rn` column instead.
data_test_list <- stratified(red, "num", 0.3, keep.rownames = TRUE)$rn
in_test <- row.names(red) %in% data_test_list
testdata <- red[in_test, ]
traindata <- red[!in_test, ]
str(traindata); str(testdata)
# Class balance should be similar in both sets after stratification.
table(traindata$num); table(testdata$num)
# Leakage checks: the sets share no row indices and no identical rows.
# Both values should be 0.
sum(row.names(traindata) %in% row.names(testdata))
sum(duplicated(rbind(traindata, testdata)))
# Feature scaling ----------------------------------------------------------
head(traindata)
##     fixed.acidity volatile.acidity citric.acid residual.sugar chlorides
## 480           9.4            0.685        0.11            2.7     0.077
## 481          10.6            0.280        0.39           15.5     0.069
## 482           9.4            0.300        0.56            2.8     0.080
## 483          10.6            0.360        0.59            2.2     0.152
## 484          10.6            0.360        0.60            2.2     0.152
## 485          10.6            0.440        0.68            4.1     0.114
##     free.sulfur.dioxide total.sulfur.dioxide density   pH sulphates alcohol num
## 480                   6                   31  0.9984 3.19      0.70    10.1   1
## 481                   6                   23  1.0026 3.12      0.66     9.2   0
## 482                   6                   17  0.9964 3.15      0.92    11.7   1
## 483                   6                   18  0.9986 3.04      1.05     9.4   0
## 484                   7                   18  0.9986 3.04      1.06     9.4   0
## 485                   6                   24  0.9970 3.06      0.66    13.4   1
# Standardise the predictors with the TRAINING set's mean and SD, then apply
# the same statistics to the test set. Scaling the test set with its own
# mean/SD put the two sets on different scales and used test-set information
# during preprocessing.
feature_cols <- setdiff(names(traindata), "num")
train_scaled <- scale(traindata[feature_cols])
traindata[feature_cols] <- train_scaled
testdata[feature_cols] <- scale(
  testdata[feature_cols],
  center = attr(train_scaled, "scaled:center"),
  scale = attr(train_scaled, "scaled:scale")
)
# Build model --------------------------------------------------------------
# Formula: `num` on every other column of the training data, i.e.
# num ~ fixed.acidity + volatile.acidity + ... + alcohol
n <- names(traindata)
f <- reformulate(setdiff(n, "num"), response = "num")
# Two hidden layers (50 and 30 units). linear.output = FALSE applies the
# activation function to the output unit as well, for the 0/1 target.
nn_fit <- neuralnet(f, data = traindata, hidden = c(50, 30),
                    linear.output = FALSE)
# Evaluate on the test set ---------------------------------------------------
feature_cols <- setdiff(names(testdata), "num")
# Predicted score for class 1 (one row per test wine, one output column).
pred_prob_nn <- predict(nn_fit, testdata[feature_cols])
test_score <- pred_prob_nn[, 1]

# ROC / AUC on the test set. The levels and direction are fixed up front
# (0 = control, 1 = case, higher score = more likely class 1). Otherwise
# pROC picks the direction from the test labels, which can hide a model that
# ranks the classes backwards.
roc_curve <- roc(testdata$num, test_score, levels = c(0, 1), direction = "<",
                 plot = TRUE, print.auc = TRUE)
roc_curve
# Stored as auc_nn so that pROC's auc() function is not shadowed.
auc_nn <- auc(roc_curve)
print(auc_nn)

# Classification cutoff: take the "best" threshold from the TRAINING
# predictions, then apply it unchanged to the test set. Tuning the cutoff on
# the test labels and then scoring on those same labels gives an optimistic
# confusion matrix.
train_score <- predict(nn_fit, traindata[feature_cols])[, 1]
roc_train <- roc(traindata$num, train_score, levels = c(0, 1),
                 direction = "<")
cutoff_nn <- coords(roc_train, "best", ret = "threshold", transpose = FALSE)
cutoff_nn
# "best" can return several tied thresholds (several rows); use the first.
cutoff_value <- cutoff_nn$threshold[1]

# Predicted classes at the chosen cutoff. The fixed factor levels keep both
# columns in the confusion matrix even if one class is never predicted.
pred_class <- factor(as.integer(test_score >= cutoff_value), levels = c(0, 1))
head(pred_class)
conf_mat <- table(observed = testdata$num, predicted = pred_class)
print(conf_mat)

# ROC plot -------------------------------------------------------------------
library(ggplot2)
# x = false positive rate (1 - specificity), y = true positive rate
# (sensitivity). The old plot put specificity on a reversed axis, so the tick
# labels under "False Positive Rate" actually showed specificity.
# geom_path() keeps pROC's point order, which traces the curve.
roc_df <- data.frame(
  fpr = 1 - roc_curve$specificities,
  tpr = roc_curve$sensitivities
)
ggplot(roc_df, aes(x = fpr, y = tpr)) +
  geom_path() +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  ggtitle(paste0("ROC Curve (AUC = ", round(as.numeric(auc_nn), 2), ")")) +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  theme_bw()

#REF https://srdas.github.io/DLBook/DeepLearningWithR.html#using-mxnet
#https://cran.r-project.org/web/packages/deepnet/deepnet.pdf