# Input data ----
# NOTE: an earlier version started with rm(list = ls()). Wiping the caller's
# workspace is a hidden side effect a script should not have, so it was
# dropped; run the script in a fresh session (e.g. Rscript) for a clean slate.
dir_path <- "C:\\Users\\liyix\\OneDrive\\Desktop\\data_for_model\\"
# Anchored, escaped regex: keep only files whose names end in ".csv".
# (The former pattern ".*.csv" treated "." as a wildcard and was unanchored,
# so names such as "xcsv_notes.txt" also matched.)
dir_path_name <- dir(dir_path, pattern = "\\.csv$", full.names = TRUE)
#dir_path_name

# Return the single path in `paths` whose file name contains `file_name` as a
# literal substring (fixed = TRUE, so "." is not a regex wildcard).
# Stops when zero or several files match, so read.csv() is never handed an
# empty or multi-element path.
find_one_file <- function(file_name, paths = dir_path_name) {
  hits <- paths[grepl(file_name, basename(paths), fixed = TRUE)]
  if (length(hits) != 1L) {
    stop(
      sprintf("Expected exactly one file matching '%s', found %d.",
              file_name, length(hits)),
      call. = FALSE
    )
  }
  hits
}

# Merge data ----
data_train <- read.csv(find_one_file("training_data.csv"),
                       header = TRUE, stringsAsFactors = FALSE)
#dim(data_train) #[1] 2357  148
#View(data_train)
data_test <- read.csv(find_one_file("test_data.csv"),
                      header = TRUE, stringsAsFactors = FALSE)
# Exploration notes recorded by the author:
#dim(data_test) #[1] 1009  150
#table(data_train$Freq) #    0    1 2057  300 
#table(data_test$Freq) #   0   1 894 115 
#View(head(data_test))
#setdiff(colnames(data_test),colnames(data_train)) #[1] "Drug"  "Virus"
#setdiff(colnames(data_train),colnames(data_test)) #character(0) 
#colnames(data_train)[ncol(data_train)] #[1] "Freq"
#data_new <- data_test
#data_train <- unique(data_train)
# Model packages ----
# Attached in the same order as before so that masking is unchanged
# (attaching pROC masks stats::cov, stats::smooth and stats::var).
#   randomForest - random-forest classifier
#   ROCR         - prediction() / performance() for the AUC
#   pROC         - roc() / coords() for the ROC cut-off
#   ROSE         - ovun.sample() over-sampling
#   DMwR         - not called anywhere in this script; its load messages
#                  mention lattice, grid and quantmod.
#                  NOTE(review): confirm it is still needed.
invisible(lapply(
  c("randomForest", "ROCR", "pROC", "ROSE", "DMwR"),
  library,
  character.only = TRUE
))
# Label and feature preparation ----
# The last training column is the label (recorded by the author as "Freq");
# rename it to "num" so the model formulas can be written as `num ~ .`.
colnames(data_train)[ncol(data_train)] <- "num"
# Coerce every column to numeric in place. lapply() always returns a list
# (sapply() would collapse a one-row frame into a plain vector), and the
# `[] <-` form keeps the existing column names and row names untouched.
data_train[] <- lapply(data_train, as.numeric)
# A factor response makes randomForest() fit a classification forest.
data_train$num <- as.factor(data_train$num)
#str(data_train$num)
# Hold-out validation (70 / 30 split) ----
# Seed the RNG so the split, the ROSE over-sampling and the forest give the
# same numbers on every run; without it the recorded metrics change run to run.
set.seed(2021L)
#dim(data_train) #[1] 2356  148
#table(data_train$num) # 0    1  2057  300 
n_holdout <- floor(nrow(data_train) * 0.3)
holdout_idx <- sample(nrow(data_train), size = n_holdout, replace = FALSE)
# Logical complement instead of -match(...): a negative empty index would
# select no training rows at all if the hold-out were ever empty.
is_holdout <- seq_len(nrow(data_train)) %in% holdout_idx
testdata_ori <- data_train[holdout_idx, ]
traindata_ori <- data_train[!is_holdout, ]
#table(testdata_ori$num) # 0   1 625  81 
#table(traindata_ori$num) # 0    1  1444  206 
#dim(testdata_ori);dim(traindata_ori) #[1] 706 148 [1] 1650  148
# Over-sample the minority class in the training part only, so the hold-out
# rows keep the original class balance.
traindata_ori <- ovun.sample(num ~ ., data = traindata_ori,
                             method = "over")$data
#table(traindata_ori$num) #  0    1 1444 1448 
fit_rf_ori <- randomForest(num ~ ., data = traindata_ori)
# Hold-out evaluation ----
# Index of the label column `num`: the last column of data_train and of the
# hold-out split taken from it.
n_col_newdata <- ncol(data_train)
holdout_truth <- testdata_ori[, n_col_newdata]
holdout_features <- testdata_ori[, -n_col_newdata]

# AUC from the predicted probability of class "1" (selected by name).
pre_auc_rf_ori <- predict(fit_rf_ori, holdout_features, type = "prob")
pred_auc_rf_ori <- prediction(pre_auc_rf_ori[, "1"], holdout_truth)
auc_rf_ori <- performance(pred_auc_rf_ori, "auc")@y.values[[1]]
auc_rf_ori # earlier runs: 0.9793877, 0.9665492

# Per-row probabilities plus predicted (num_1) and observed (num) class.
pre_probability_ori <- data.frame(pre_auc_rf_ori)
# The label is dropped with the hold-out's own width. The previous
# `-ncol(data_test)` used the wider test table (150 vs 148 columns), so it
# did not remove the label column as intended.
pre_classification_ori <- predict(fit_rf_ori, holdout_features,
                                  type = "class")
pre_probability_ori$num_1 <- as.character(pre_classification_ori)
pre_probability_ori$num <- holdout_truth
#table(pre_probability_ori$num,pre_probability_ori$num_1)

# ROC-based cut-off: pROC reports control = "0", case = "1".
modelroc_rf_ori <- pROC::roc(holdout_truth,
                             as.numeric(pre_auc_rf_ori[, "1"]))
cutoff_rf_ori <- pROC::coords(modelroc_rf_ori, "best", ret = "threshold")
# coords() returns a numeric vector in older pROC releases and a data.frame
# in newer ones, and "best" may return several tied thresholds;
# unlist()[1] yields the first threshold as one number in every case.
cutoff_value_rf_ori <- as.numeric(unlist(cutoff_rf_ori))[1]
predict.results_rf_ori <- ifelse(pre_auc_rf_ori[, "1"] > cutoff_value_rf_ori,
                                 "1", "0")

# Rows = predicted class, columns = observed class. Fixing both factor levels
# keeps the table 2 x 2 even if every row is predicted as one class (the old
# positional [2, 2] lookup would then fail with "subscript out of bounds").
freq_default_rf_ori <- table(
  predicted = factor(predict.results_rf_ori, levels = c("0", "1")),
  observed = factor(holdout_truth, levels = c("0", "1"))
)
# Class "1" is the positive class (pROC's case). The previous code had these
# two names swapped: [1, 1] / column-1 total is TN / (TN + FP), which is
# specificity, not sensitivity.
Sensitivity_rf_ori <- freq_default_rf_ori["1", "1"] /
  sum(freq_default_rf_ori[, "1"])
Specificity_rf_ori <- freq_default_rf_ori["0", "0"] /
  sum(freq_default_rf_ori[, "0"])
accuracy_rf_ori <- mean(as.character(holdout_truth) == predict.results_rf_ori)
# The mean of the two rates, so its value is unaffected by the naming fix.
Balanced_accuracy_rf_ori <- (Sensitivity_rf_ori + Specificity_rf_ori) / 2
Balanced_accuracy_rf_ori # earlier runs: 0.954963, 0.9169848

library(caret)
# Attaching caret also attaches ggplot2, whose margin() masks
# randomForest::margin().
# caret::confusionMatrix() treats the first level ("0") as positive unless
# told otherwise; pass positive = "1" to match the rates computed above:
#confusionMatrix(factor(predict.results_rf_ori, levels = c("0", "1")), holdout_truth, positive = "1")
# Previously recorded confusion matrix. NOTE(review): it totals 1494 rows,
# not the 706-row hold-out, so it presumably comes from another run/dataset.
#####0    1
#0 1350   15
#1  116   13
# Final model: random forest on all (over-sampled) training data ----
# This header previously read "svm", but the model fitted here is
# randomForest(); the commented svm() call is a leftover alternative.
# Seeded so the over-sampling and the forest are reproducible on their own.
set.seed(2021L)
data_train <- ovun.sample(num ~ ., data = data_train, method = "over")$data
#dim(data_train) #[1] 4129  148
#fit_probability <- svm(num ~ .,data = data_train, scale = F, probability=TRUE )
# Stale notes from the earlier svm run:
#dim(data_train) #[1] 4981  116
#table(data_train$num) #4515  466 
fit_probability <- randomForest(num ~ ., data = data_train)
# AUC on the external test set ----
# data_test holds the label plus two identifier columns ("Drug", "Virus") that
# data_train does not have. Select them by name instead of by position so a
# column reorder cannot silently feed an identifier to the model or drop a
# feature; stop early if any of them is missing.
#dim(data_test) #[1] 1009  150
test_meta_cols <- c("Freq", "Virus", "Drug")
stopifnot(all(test_meta_cols %in% colnames(data_test)))
test_features <- data_test[, setdiff(colnames(data_test), test_meta_cols),
                           drop = FALSE]
pre_auc_probability <- predict(fit_probability, test_features, type = "prob")
pre_probability <- data.frame(pre_auc_probability)
# The probability columns are named after the class levels ("0", "1");
# build the output names from them rather than patching the "X0"/"X1"
# names that data.frame() generates.
colnames(pre_probability) <- paste0("probability_",
                                    colnames(pre_auc_probability))
#dim(pre_probability) #[1] 1009    2
# Probabilities followed by the label and identifiers: Freq, Virus, Drug.
data_1 <- cbind(pre_probability, data_test[, test_meta_cols])
#table(cut(pre_probability$probability_1,breaks = c(-0.1,.5,1)))
#write.table(pre_probability,paste0(dir_path,Sys.Date(),"-ecfp4_auc_rf_original.txt"),row.names = F,sep = "\t")
#write.csv(pre_probability,paste0(dir_path,Sys.Date(),"-ecfp4_auc_RF_original_0_52.csv"),row.names = F)
pred_auc <- prediction(data_1$probability_1, data_1$Freq)
auc_roc <- performance(pred_auc, "auc")@y.values[[1]]
auc_roc # earlier runs: 0.9817041, 0.9794621
write.csv(data_1, paste0(dir_path, Sys.Date(), "-RF_upsample_probability.csv"),
          row.names = FALSE)
# Plot ----
#View(data_1)
head(data_1)
##   probability_0 probability_1 Freq         Virus                  Drug
## 1         0.896         0.104    0         HIV-2               arbidol
## 2         0.902         0.098    0 hCoV-19-Alpha        propagermanium
## 3         0.982         0.018    0 hCoV-19-Delta            telaprevir
## 4         0.922         0.078    0        HPV-11 tenofovir-alafenamide
## 5         0.966         0.034    0           HBV            cobicistat
## 6         0.948         0.052    0         HCV-7           chloroquine
library(ggplot2)
# Colour the density curves by observed class.
data_1$Freq <- factor(data_1$Freq)
ggplot(data_1) +
  stat_density(aes(x = probability_1, color = Freq), size = 1,
               alpha = 0.5, bw = 0.01, geom = "line", position = "identity") +
  # Density is a continuous y. The former scale_y_discrete() is a
  # ridge-plot idiom for a discrete y and appears to leave this axis without
  # tick labels. NOTE(review): confirm the y ticks are wanted.
  scale_y_continuous(expand = c(0.01, 0)) +
  scale_x_continuous(expand = c(0, 0), limits = c(-0.05, 1.05),
                     breaks = seq(0, 1, .2)) +
  labs(colour = "Antiviral drug", x = "Probability", y = "Density",
       title = "RF") +
  # Reference line at the default 0.5 class boundary.
  geom_vline(aes(xintercept = 0.5),
             color = "gray80", linetype = "dashed", size = 0.5) +
  theme(panel.spacing = unit(0.1, "cm"),
        legend.position = "top",
        legend.key = element_rect(colour = NA, fill = NA),
        legend.text = element_text(size = 14),
        legend.title = element_text(size = 14),
        axis.ticks = element_line(colour = "black",
                                  size = 0.5, linetype = "solid"),
        axis.line = element_line(colour = "black",
                                 size = 0.5, linetype = "solid"),
        axis.text = element_text(face = "plain", color = "black",
                                 family = "sans", size = 14, angle = 0),
        panel.background = element_rect(fill = "white",
                                        colour = "white",
                                        size = 0.5, linetype = "solid"),
        panel.grid.major = element_line(size = 1, linetype = "dashed",
                                        colour = "white"),
        axis.title = element_text(color = "black", size = 14, face = "plain",
                                  family = "sans")) +
  scale_color_manual(name = "Antiviral drug",
                     values = c("#3b58a7", "#90278e"))

# Output ----
# The model is a random forest (plot title "RF", CSV "-RF_..."); the file was
# previously saved as "...probability_plot_SVM.tif".
ggsave(filename = paste0(Sys.Date(), "-probability_plot_RF.tif"),
       plot = last_plot(),
       device = "tiff", path = dir_path,
       scale = 1, width = 16, height = 12, units = "cm",
       dpi = 300, limitsize = TRUE, compression = "lzw")
#View(data_1)
# Per-class breakdown ----
# Split the test predictions by observed class (Freq == 1 vs. the rest) and
# count how many probabilities fall in [0, 0.5] versus (0.5, 1].
is_case <- data_1$Freq == 1
data_11 <- data_1[is_case, ]
data_22 <- data_1[!is_case, ]
#dim(data_11) #[1] 115   5

# Tabulate probabilities into the two bins used above.
bin_probabilities <- function(p) {
  table(cut(p, breaks = c(0, 0.5, 1), include.lowest = TRUE))
}

print(bin_probabilities(data_11$probability_1))
## 
## [0,0.5] (0.5,1] 
##      12     103
print(bin_probabilities(data_22$probability_1))
## 
## [0,0.5] (0.5,1] 
##     884      10