# Data source: http://archive.ics.uci.edu/ml/datasets/image+segmentation
library(keras)
## Warning: package 'keras' was built under R version 3.4.4
library(ggplot2)
## Warning: package 'ggplot2' was built under R version 3.4.2
library(pheatmap)
## Warning: package 'pheatmap' was built under R version 3.4.4
## Data preparation
# NOTE(review): with row.names = NULL the first CSV column is kept as an
# ordinary column named `row.names`; judging by its use below it holds the
# class name. Presumably the header has one field fewer than the data rows —
# confirm against the CSV.
imsedata <- read.csv("image segmentation.csv", header = TRUE, row.names = NULL)
# Drop the REGION.PIXEL.COUNT feature (presumably constant in this data set —
# verify), leaving 18 feature columns after the label column.
imsedata$REGION.PIXEL.COUNT <- NULL
# Encode the class label as 0-based integers (factor levels in alphabetical
# order), the form passed to to_categorical() further down.
imsedata$row.names <- as.integer(as.factor(imsedata$row.names)) - 1L
# The one-hot encoding and the softmax output layer later in this script
# hard-code 7 classes; fail early if the data disagrees.
stopifnot(length(unique(imsedata$row.names)) == 7L)
head(imsedata)
summary(imsedata)
## row.names REGION.CENTROID.COL REGION.CENTROID.ROW SHORT.LINE.DENSITY.5
## Min. :0 Min. : 1.0 Min. : 11.0 Min. :0.000000
## 1st Qu.:1 1st Qu.: 60.5 1st Qu.: 81.5 1st Qu.:0.000000
## Median :3 Median :123.5 Median :121.5 Median :0.000000
## Mean :3 Mean :124.6 Mean :122.8 Mean :0.008466
## 3rd Qu.:5 3rd Qu.:189.8 3rd Qu.:174.5 3rd Qu.:0.000000
## Max. :6 Max. :252.0 Max. :250.0 Max. :0.111111
## SHORT.LINE.DENSITY.2 VEDGE.MEAN VEDGE.SD
## Min. :0.000000 Min. : 0.0000 Min. : 0.0000
## 1st Qu.:0.000000 1st Qu.: 0.6667 1st Qu.: 0.4009
## Median :0.000000 Median : 1.2222 Median : 0.8287
## Mean :0.006349 Mean : 1.9251 Mean : 5.7195
## 3rd Qu.:0.000000 3rd Qu.: 1.8889 3rd Qu.: 1.6766
## Max. :0.222222 Max. :25.5000 Max. :572.9964
## HEDGE.MEAN HEDGE.SD INTENSITY.MEAN RAWRED.MEAN
## Min. : 0.0000 Min. : 0.0000 Min. : 0.000 Min. : 0.00
## 1st Qu.: 0.7778 1st Qu.: 0.4108 1st Qu.: 6.454 1st Qu.: 7.00
## Median : 1.3889 Median : 0.9132 Median : 21.315 Median : 18.61
## Mean : 2.6042 Mean : 11.6384 Mean : 37.091 Mean : 32.97
## 3rd Qu.: 2.5972 3rd Qu.: 1.9805 3rd Qu.: 52.630 3rd Qu.: 46.75
## Max. :44.7222 Max. :1386.3292 Max. :143.444 Max. :136.89
## RAWBLUE.MEAN RAWGREEN.MEAN EXRED.MEAN EXBLUE.MEAN
## Min. : 0.000 Min. : 0.000 Min. :-48.222 Min. :-9.667
## 1st Qu.: 8.278 1st Qu.: 3.806 1st Qu.:-18.111 1st Qu.: 4.111
## Median : 26.833 Median : 20.000 Median :-10.333 Median :19.556
## Mean : 44.011 Mean : 34.294 Mean :-12.370 Mean :20.760
## 3rd Qu.: 64.194 3rd Qu.: 46.472 3rd Qu.: -4.667 3rd Qu.:34.333
## Max. :150.889 Max. :142.556 Max. : 5.778 Max. :78.778
## EXGREEN.MEAN VALUE.MEAN SATURATION.MEAN HUE.MEAN
## Min. :-30.556 Min. : 0.00 Min. :0.0000 Min. :-2.531
## 1st Qu.:-15.750 1st Qu.: 10.53 1st Qu.:0.2757 1st Qu.:-2.187
## Median : -9.889 Median : 28.39 Median :0.3655 Median :-2.044
## Mean : -8.390 Mean : 44.89 Mean :0.4232 Mean :-1.340
## 3rd Qu.: -3.722 3rd Qu.: 64.19 3rd Qu.:0.5397 3rd Qu.:-1.430
## Max. : 21.889 Max. :150.89 Max. :1.0000 Max. : 2.865
table(imsedata$row.names)
##
## 0 1 2 3 4 5 6
## 30 30 30 30 30 30 30
## Convert the data frame to a numeric matrix: column 1 is the 0-based class
## label, columns 2:19 are the 18 features.
imsedata <- as.matrix(imsedata)
## Train/test split: a random 70% of the rows for training, the rest for test.
set.seed(123)
index <- sample(nrow(imsedata), size = round(nrow(imsedata) * 0.7))
train_x <- imsedata[index, 2:19]
train_y <- to_categorical(imsedata[index, 1], 7)
test_x <- imsedata[-index, 2:19]
test_y <- to_categorical(imsedata[-index, 1], 7)
## Standardization (z-score per feature).
## The centre and scale are estimated on the TRAINING rows only and then
## applied to all rows. Estimating them on the full matrix would leak
## test-set statistics into the training inputs and bias the test score.
train_center <- colMeans(train_x)
train_sd <- apply(train_x, 2, sd)
# A feature that is constant in the training rows has sd 0; divide by 1
# instead so the column becomes 0 rather than NaN.
train_sd[train_sd == 0] <- 1
imsedatascale <- sweep(imsedata[, 2:19], 2, train_center, "-")
imsedatascale <- sweep(imsedatascale, 2, train_sd, "/")
## Split the standardized matrix with the same row index as the raw one.
train_xsc <- imsedatascale[index, ]
test_xsc <- imsedatascale[-index, ]
## Network architecture:
##   18 features -> Dense(64, ReLU) -> Dropout(0.25)
##               -> Dense(32, ReLU) -> GaussianDropout(0.25)
##               -> Dense(7, softmax), one output per class
model <- keras_model_sequential() %>%
  layer_dense(units = 64, activation = "relu", input_shape = 18, name = "den1") %>%
  layer_dropout(rate = 0.25) %>%
  layer_dense(units = 32, activation = "relu", name = "den2") %>%
  layer_gaussian_dropout(rate = 0.25) %>%
  layer_dense(units = 7, activation = "softmax")
summary(model)
## ___________________________________________________________________________
## Layer (type) Output Shape Param #
## ===========================================================================
## den1 (Dense) (None, 64) 1216
## ___________________________________________________________________________
## dropout_1 (Dropout) (None, 64) 0
## ___________________________________________________________________________
## den2 (Dense) (None, 32) 2080
## ___________________________________________________________________________
## gaussian_dropout_1 (GaussianDrop (None, 32) 0
## ___________________________________________________________________________
## dense_1 (Dense) (None, 7) 231
## ===========================================================================
## Total params: 3,527
## Trainable params: 3,527
## Non-trainable params: 0
## ___________________________________________________________________________
## Compile: categorical cross-entropy on the one-hot labels, Adam optimizer,
## accuracy reported as the metric.
model %>% compile(
  loss = "categorical_crossentropy",
  optimizer = optimizer_adam(),
  metrics = c("accuracy")
)
## Baseline run: train on the raw (unstandardized) features.
mod_history <- fit(
  model, train_x, train_y,
  epochs = 100,
  batch_size = 8,
  validation_split = 0.2,
  verbose = 0
)
## Plot the training history.
plot(mod_history) + theme_bw() + ggtitle("Don't scale")
## Loss and accuracy on the held-out test set.
evaluate(model, test_x, test_y)
## $loss
## [1] 1.14371
##
## $acc
## [1] 0.8095238
## Training run on the standardized features.
## Start from a freshly initialised network. Calling fit() on `model` again
## would continue from the weights already learned on the unscaled features,
## so the scaled run would not start from an untrained network as the baseline
## did, and the comparison would be confounded. clone_model() rebuilds the
## same architecture with new weights; the clone has to be compiled again.
model <- clone_model(model)
model %>% compile(
  loss = "categorical_crossentropy",
  optimizer = optimizer_adam(),
  metrics = c("accuracy")
)
mod_historysc <- model %>% fit(
  train_xsc, train_y,
  epochs = 100, batch_size = 8,
  validation_split = 0.2, verbose = 0
)
## Visualize the training process.
plot(mod_historysc) +
  theme_bw() + ggtitle("scale")
## Loss and accuracy on the standardized test set.
model %>% evaluate(test_xsc, test_y)
## NOTE(review): the values below were recorded when this step continued
## training the already-fitted model; re-run to refresh them.
## $loss
## [1] 0.3914827
##
## $acc
## [1] 0.8412698
## Inspect the learned weights of the final model.
## get_weights() returns a list of weight arrays. For this model it has 6
## entries, and the Dense kernels sit at positions 1, 3 and 5, with shapes
## 18 x 64, 64 x 32 and 32 x 7. The even positions hold the bias vectors.
model_we <- get_weights(model)
length(model_we)
## [1] 6
## One heatmap per Dense kernel (rows = inputs, columns = units). Unclustered,
## so rows and columns stay in their original order.
## Recorded dims: [1] 18 64 / [1] 64 32 / [1] 32 7
kernel_pos <- seq(1L, length(model_we), by = 2L)
for (layer_no in seq_along(kernel_pos)) {
  w <- model_we[[kernel_pos[layer_no]]]
  print(dim(w))
  pheatmap(w, cluster_rows = FALSE, cluster_cols = FALSE,
           labels_row = seq_len(nrow(w)), labels_col = seq_len(ncol(w)),
           main = paste("layer", layer_no, "weight"))
}