source("D:/R/R/NN/plot.nn.r") # отрисовка сети
data_diabetes = read.csv("D:/R/R/NN/prima-indians-diabetes.csv",header = FALSE, encoding = 'UTF-8')
#V1. Число беременностей (все пациенты из источника – женщины не моложе 21 года индийской народности пима).
#V2. Концентрация глюкозы в плазме через 2 часа после введения в пероральном глюкозотолерантном тесте
#V3. Диастолическое артериальное давление (мм рт. ст.).
#V4. Толщина кожной складки в районе трицепса (мм).
#V5. Концентрация инсулина в сыворотке крови (мкЕд/мл).
#V6. Индекс массы тела (вес в кг/(рост в м)^2).
#V7. Функция, описывающая генетическую предрасположенность к диабету (diabetes pedegree).
#V8. Возраст (годы).
#V9. Диагноз - целевой, показывает, наблюдался ли у пациента сахарный диабет или нет (соответственно, 1 или 0).
#install (if not already installed) and load data.table package
if(!require(data.table)){install.packages('data.table')}
## Loading required package: data.table
setnames(data_diabetes, old=c("V1","V2","V6","V8", "V9"),
new=c("Pregnancy", "Glucose", "BMI", "Age", "diagnosis"))
#view new column names
names(data_diabetes)
## [1] "Pregnancy" "Glucose" "V3" "V4" "V5" "BMI"
## [7] "V7" "Age" "diagnosis"
data_diabetes$diagnosis = factor(data_diabetes$diagnosis,
levels = c(1, 0),
labels = c("diabetes", "no diabetes"))
table(data_diabetes$diagnosis)
##
## diabetes no diabetes
## 268 500
#summary(data_diabetes[-9])
summary(data_diabetes[1:8])
## Pregnancy Glucose V3 V4
## Min. : 0.000 Min. : 0.0 Min. : 0.00 Min. : 0.00
## 1st Qu.: 1.000 1st Qu.: 99.0 1st Qu.: 62.00 1st Qu.: 0.00
## Median : 3.000 Median :117.0 Median : 72.00 Median :23.00
## Mean : 3.845 Mean :120.9 Mean : 69.11 Mean :20.54
## 3rd Qu.: 6.000 3rd Qu.:140.2 3rd Qu.: 80.00 3rd Qu.:32.00
## Max. :17.000 Max. :199.0 Max. :122.00 Max. :99.00
## V5 BMI V7 Age
## Min. : 0.0 Min. : 0.00 Min. :0.0780 Min. :21.00
## 1st Qu.: 0.0 1st Qu.:27.30 1st Qu.:0.2437 1st Qu.:24.00
## Median : 30.5 Median :32.00 Median :0.3725 Median :29.00
## Mean : 79.8 Mean :31.99 Mean :0.4719 Mean :33.24
## 3rd Qu.:127.2 3rd Qu.:36.60 3rd Qu.:0.6262 3rd Qu.:41.00
## Max. :846.0 Max. :67.10 Max. :2.4200 Max. :81.00
######################################################
set.seed(42)
indicate = sample(2, nrow(data_diabetes),
replace=TRUE, prob=c(0.75, 0.25))
diabetes_train = data_diabetes[indicate == 1, ] #75% of sample
#diabetes_train
diabetes_test = data_diabetes[indicate == 2, ] #25% of sample
#########################################
library(neuralnet)
model = neuralnet(diagnosis ~ Glucose + BMI + Age + Pregnancy,
data = diabetes_train, hidden = 5)
#neuralnet(formula, data, hidden = 1, threshold = 0.01,
# stepmax = 1e+05, rep = 1, startweights = NULL,
# learningrate.limit = NULL,
# learningrate.factor = list(minus = 0.5, plus = 1.2),
# learningrate = NULL, lifesign = "none",
# lifesign.step = 1000, algorithm = "rprop+", err.fct = "sse",
# act.fct = "logistic", linear.output = TRUE, exclude = NULL,
# constant.weights = NULL, likelihood = FALSE)
model$result.matrix
## [,1]
## error 9.239583e+01
## reached.threshold 8.873085e-03
## steps 5.155600e+04
## Intercept.to.1layhid1 6.257237e+00
## Glucose.to.1layhid1 -2.620604e-02
## BMI.to.1layhid1 -6.132968e-02
## Age.to.1layhid1 -8.530991e-03
## Pregnancy.to.1layhid1 -8.009666e-02
## Intercept.to.1layhid2 -2.865044e+01
## Glucose.to.1layhid2 4.141174e-01
## BMI.to.1layhid2 -6.293199e-01
## Age.to.1layhid2 8.416744e-01
## Pregnancy.to.1layhid2 2.385949e+01
## Intercept.to.1layhid3 8.792873e+01
## Glucose.to.1layhid3 1.522377e+00
## BMI.to.1layhid3 -2.957006e+00
## Age.to.1layhid3 -1.202493e+00
## Pregnancy.to.1layhid3 -5.061293e+00
## Intercept.to.1layhid4 7.014208e+00
## Glucose.to.1layhid4 1.334913e+00
## BMI.to.1layhid4 -1.410083e-01
## Age.to.1layhid4 5.814709e-02
## Pregnancy.to.1layhid4 2.998551e-01
## Intercept.to.1layhid5 -2.285281e+01
## Glucose.to.1layhid5 -7.544407e-01
## BMI.to.1layhid5 3.130489e-01
## Age.to.1layhid5 5.864803e-01
## Pregnancy.to.1layhid5 3.373898e+00
## Intercept.to.diabetes -5.043365e-01
## 1layhid1.to.diabetes -1.118985e+00
## 1layhid2.to.diabetes 7.110318e-01
## 1layhid3.to.diabetes 4.184175e-01
## 1layhid4.to.diabetes 4.047886e-01
## 1layhid5.to.diabetes 1.459651e+00
## Intercept.to.no diabetes -6.016502e-01
## 1layhid1.to.no diabetes 1.118992e+00
## 1layhid2.to.no diabetes -7.204854e-01
## 1layhid3.to.no diabetes -4.175618e-01
## 1layhid4.to.no diabetes 1.709781e+00
## 1layhid5.to.no diabetes -1.450107e+00
model$weights
## [[1]]
## [[1]][[1]]
## [,1] [,2] [,3] [,4] [,5]
## [1,] 6.257237428 -28.6504387 87.928728 7.01420848 -22.8528118
## [2,] -0.026206039 0.4141174 1.522377 1.33491259 -0.7544407
## [3,] -0.061329682 -0.6293199 -2.957006 -0.14100825 0.3130489
## [4,] -0.008530991 0.8416744 -1.202493 0.05814709 0.5864803
## [5,] -0.080096658 23.8594942 -5.061293 0.29985505 3.3738978
##
## [[1]][[2]]
## [,1] [,2]
## [1,] -0.5043365 -0.6016502
## [2,] -1.1189845 1.1189923
## [3,] 0.7110318 -0.7204854
## [4,] 0.4184175 -0.4175618
## [5,] 0.4047886 1.7097815
## [6,] 1.4596514 -1.4501073
plot.nn(model)
net.prob = compute(model, diabetes_test)$net.result
#net.prob
y_pred = c("diabetes", "no diabetes")[apply(net.prob, 1, which.max)]
#y_pred
table(Прогноз = y_pred, Факт = diabetes_test$diagnosis)
## Факт
## Прогноз diabetes no diabetes
## diabetes 44 14
## no diabetes 25 106
Accuracy = mean(y_pred == diabetes_test$diagnosis)
Accuracy
## [1] 0.7936508
###################################################
library(gmodels)
CrossTable(x = diabetes_test$diagnosis, y = y_pred, prop.chisq=FALSE)
##
##
## Cell Contents
## |-------------------------|
## | N |
## | N / Row Total |
## | N / Col Total |
## | N / Table Total |
## |-------------------------|
##
##
## Total Observations in Table: 189
##
##
## | y_pred
## diabetes_test$diagnosis | diabetes | no diabetes | Row Total |
## ------------------------|-------------|-------------|-------------|
## diabetes | 44 | 25 | 69 |
## | 0.638 | 0.362 | 0.365 |
## | 0.759 | 0.191 | |
## | 0.233 | 0.132 | |
## ------------------------|-------------|-------------|-------------|
## no diabetes | 14 | 106 | 120 |
## | 0.117 | 0.883 | 0.635 |
## | 0.241 | 0.809 | |
## | 0.074 | 0.561 | |
## ------------------------|-------------|-------------|-------------|
## Column Total | 58 | 131 | 189 |
## | 0.307 | 0.693 | |
## ------------------------|-------------|-------------|-------------|
##
##