library(mlbench)
library(keras)
# Data preparation ----
# Load the breast-cancer data set shipped with mlbench.
data("BreastCancer", package = "mlbench")
data2 <- na.omit(BreastCancer)  # keep only complete rows
data2 <- data2[-1]              # drop the first column (sample Id)

# The nine feature columns are factors whose labels are the scores. Convert
# through character so the number is the label itself (not the factor's
# internal level index), and so that as.matrix() below yields a numeric
# matrix rather than a character one.
feature_cols <- 1:9
data2[feature_cols] <- lapply(
  data2[feature_cols],
  function(col) as.numeric(as.character(col))
)
# Encode the class once, on the full data, so train and test share the same
# mapping: first factor level -> 0, second -> 1.
data2[[10]] <- as.numeric(as.factor(data2[[10]])) - 1
data2 <- as.matrix(data2)

# Random ~80/20 train/test split; seeded so the split is reproducible.
set.seed(123)
split_idx <- sample(1:2, nrow(data2), prob = c(0.8, 0.2), replace = TRUE)
train <- data2[split_idx == 1, , drop = FALSE]
test <- data2[split_idx == 2, , drop = FALSE]

# num_classes = 2 keeps two one-hot columns even if a split lacks a class.
trainx <- train[, feature_cols, drop = FALSE]
train_y <- train[, 10]
y_train <- to_categorical(train_y, num_classes = 2)

testx <- test[, feature_cols, drop = FALSE]
test_y <- test[, 10]
y_test <- to_categorical(test_y, num_classes = 2)
# Model definition ----
# Fully connected network: 9 input features -> 512 -> 10 -> 10 -> 2 outputs.
# The labels are one-hot over two mutually exclusive classes, so the output
# layer uses softmax with categorical cross-entropy. Two independent
# sigmoids with binary cross-entropy would not make the class probabilities
# sum to 1, and Keras would then report "accuracy" as per-output binary
# accuracy rather than per-sample classification accuracy.
model <- keras_model_sequential()
model %>%
  layer_dense(units = 512, activation = "relu", input_shape = 9) %>%
  layer_dense(units = 10, activation = "relu") %>%
  layer_dense(units = 10, activation = "relu") %>%
  layer_dense(units = 2, activation = "softmax")

model %>% compile(
  loss = "categorical_crossentropy",
  optimizer = "adam",
  metrics = c("accuracy")
)
summary(model)
## ___________________________________________________________________________
## Layer (type) Output Shape Param #
## ===========================================================================
## dense (Dense) (None, 512) 5120
## ___________________________________________________________________________
## dense_1 (Dense) (None, 10) 5130
## ___________________________________________________________________________
## dense_2 (Dense) (None, 10) 110
## ___________________________________________________________________________
## dense_3 (Dense) (None, 2) 22
## ===========================================================================
## Total params: 10,382
## Trainable params: 10,382
## Non-trainable params: 0
## ___________________________________________________________________________
# Training and evaluation ----
# keras needs numeric input. as.matrix() on a data frame of factor columns
# gives a character matrix, so force double storage. This is a no-op when
# the features are already numeric, and dimensions are preserved.
storage.mode(trainx) <- "double"
storage.mode(testx) <- "double"
stopifnot(is.matrix(trainx), is.numeric(trainx))

history <- model %>% fit(
  trainx, y_train,
  epochs = 50,
  batch_size = 5,
  validation_split = 0.2
)
plot(history)

score <- model %>% evaluate(testx, y_test)
# Look metrics up by exact name with `[[`. The accuracy metric is named
# "acc" or "accuracy" depending on the keras version, and `score$acc` relied
# on `$` partial matching, which also fails on an atomic named vector.
acc_name <- intersect(c("accuracy", "acc"), names(score))
if (length(acc_name) == 0) {
  stop("evaluate() returned no accuracy metric", call. = FALSE)
}
cat("Test loss:", score[["loss"]], "\n")
cat("Test accuracy:", score[[acc_name[1]]], "\n")
# Prediction and confusion table ----
# predict_classes() is deprecated/removed in recent keras releases, so take
# the column of the largest predicted class score instead. max.col() is
# 1-based; subtract 1 to match the 0/1 coding of test_y.
probs <- predict(model, testx)
sam <- max.col(probs, ties.method = "first") - 1L
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
# Fix the factor levels so both classes always appear in the table, even
# if a class is never predicted or is absent from the test split.
conf_tab <- table(
  predicted = factor(sam, levels = 0:1),
  actual = factor(test_y, levels = 0:1)
)
conf_tab
# Overall test-set accuracy: the share of test rows on the diagonal.
sum(diag(conf_tab)) / sum(conf_tab)