knitr::opts_chunk$set(comment=NULL)
setwd("E:\\study\\kaggle\\mushroom-classification")
rm(list = ls())
library(tidyverse)
library(rpart) #决策树
library(rpart.plot)
library(vcd) #马赛克图
library(grid)
library(caret) #交叉验证

研究背景及意义

2018年《中国卫生健康统计年鉴》的数据显示,2016年全国因食用毒蘑菇、菜豆等动植物中毒的人数共7342人,其中58.8%因误食毒蘑菇中毒。而2017年全国因误食毒蘑菇中毒的事件不减反增,从2016年的991起增加到1410起,中毒人数则从4230人增加至5481人,分别比上年增加了42.3%和29.6%。此种现象的产生很可能是由于毒蘑菇辨别方法的误导。一个典型的错误辨别方法是:颜色鲜艳的蘑菇有毒,颜色普通的蘑菇无毒。然而很多色彩不鲜艳的蘑菇,比如白色的致命白毒伞、白毒鹅膏菌和白霜杯伞等,都是毒蘑菇;而如橙盖鹅膏、鸡油菌、金顶侧耳、双色牛肝菌和正红菇等,都是颜色鲜艳的食用菌。

Kaggle的Mushroom Classification数据集提取自《奥杜邦学会北美蘑菇实地指南》(1981年)。每种蘑菇都被确定为绝对可食用、肯定有毒或不确定,数据集中将后两类结合在一起定义为有毒。此外,每种蘑菇都有22个属性,这些属性给出了蘑菇的菌盖、菌褶、菌柄和菌环等各个部位的信息。笔者认为通过分析蘑菇各部位的形状、颜色、气味等属性,判断哪种蘑菇更可能有毒可以帮助减少因误食毒蘑菇中毒的事件,并破除流传甚广的各种毒蘑菇辨别谣言。

数据处理

数据描述

mushroom<-read.csv("mushrooms.csv",header=T)
dim(mushroom)
names(mushroom)
str(mushroom)

本文所用数据来自Kaggle的Mushroom Classification数据集。该数据集有8124个观测值,23个变量。其中,变量class表示是否有毒,水平“e”表示可食用,“p”表示有毒;其余变量都是分类变量,表示蘑菇各部位的形状、颜色、气味等属性,全都用字母表示其水平,如下所示:

变量信息
变量分类 变量名称 变量描述
毒性信息 class 是否有毒(“e”表示可食用,“p”表示有毒)
菌盖信息 cap-shape 菌盖形状
cap-surface 菌盖表面性质
cap-color 菌盖颜色
veil.type 菌盖面类型
veil.color 菌盖面颜色
菌褶信息 gill.attachment 菌褶着生状态
gill.spacing 菌褶间隔
gill.size 菌褶大小
gill.color 菌褶颜色
菌柄信息 stalk.shape 菌柄形状
stalk.root 菌柄根性质
stalk.surface.above.ring 菌柄环上表面
stalk.surface.below.ring 菌柄环下表面
stalk.color.above.ring 菌柄环上颜色
stalk.color.below.ring 菌柄环下颜色
菌环信息 ring.number 菌环数目
ring.type 菌环类型
其他 bruises 菌体
odor 气味
spore.print.color 孢子印颜色
population 人群分布
habitat 环境

分析时将变量class看作因变量,其余变量作为自变量。这是一个二分类问题,而且由于自变量全是分类变量,因此logistic回归、支持向量机和K最近邻方法等基于数量变量的方法都不能正常运行,本文仅采用决策树分类方法进行分析。

数据预处理

na_flag=map(mushroom,~sum(is.na(.)))
na_flag[which(na_flag!=0)]
unique=map(mushroom,~length(unique(.)))
unique[which(unique==1)]
table1=data.frame(table(mushroom$stalk.root))
table1=rename(table1,level=Var1,number=Freq)
knitr::kable(table1,caption="变量stalk.root各取值水平下的观测值数量")
变量stalk.root各取值水平下的观测值数量
level number
? 2480
b 3776
c 556
e 1120
r 192

使用is.na()函数查看数据集可知,该数据集无明确缺失值。但是查看变量stalk.root的结构可知,它有一个水平为“?”,代表数据缺失,由表2可知,缺失数据达2480个,占观测值总量的30.5%,且无法插补,因此删去变量stalk.root。

table2=data.frame(table(mushroom$veil.type))
table2=rename(table2,level=Var1,number=Freq)
knitr::kable(table2,caption="变量veil.type各取值水平下的观测值数量")
变量veil.type各取值水平下的观测值数量
level number
p 8124

此外,变量veil.type仅有1个水平,无法作为自变量进行分析,因此删去变量veil.type。

mushroom=subset(mushroom,select=-c(veil.type,stalk.root))

处理后的数据集有8124个观测值,21个变量。

划分训练集和测试集

table3=data.frame(prop.table(table(mushroom$class)))
table3=rename(table3,level=Var1,percent=Freq)
knitr::kable(table3,caption="变量class各取值水平下的观测值数量")
变量class各取值水平下的观测值数量
level percent
e 0.5179714
p 0.4820286

由表4可知,该数据集中可食用的蘑菇和有毒的蘑菇分别占51.8%和48.2%,数据十分平衡。采用简单随机抽样划分训练集和测试集。利用函数sample()随机抽取80%的数据作为训练集,剩余的数据为测试集。

len=dim(mushroom)[1]
set.seed(123)
test_id=sample(1:len,len*0.2,replace=F)
test_data=mushroom[test_id,]
train_data=mushroom[-test_id,]
dim(train_data)
dim(test_data)

得到的训练集和测试集分别有6500个和1624个观测值。

特征分析

在数据建模时,使用特征工程处理数据,能够减少算法模型受到的噪声干扰,从而更好地找出数据的潜在趋势。事实上,好的特征甚至能够帮我们实现使用简单的模型达到很好的效果。

相关性分析

write.csv(train_data,file="train_data.csv",row.names=F)
train_data1<-read.csv("train_data.csv",header=F)
train_data2=train_data
train_data2$class=as.numeric(train_data2$class)
train_data2$class[which(train_data2$class==1)]=0
train_data2$class[which(train_data2$class==2)]=1 #0可食用,1有毒

笔者以各变量的不同取值水平中毒蘑菇的占比来刻画变量水平与毒性的相关性,如果毒蘑菇占比小于25%或大于75%,则该变量水平与毒性高度相关。以变量cap.shape(菌盖形状)为例:

x=data.frame(table(train_data2$cap.shape))
y=data.frame(class=train_data2$class,cap.shape=train_data2$cap.shape) %>%
  group_by(cap.shape) %>%
  summarise(posion_number=sum(class)) %>%
  left_join(.,x,by=c("cap.shape"="Var1"))
z=y %>%
  mutate(danger=posion_number/Freq) %>%
  rename(sum_number=Freq)
knitr::kable(z,caption="不同菌盖形状毒蘑菇占比")
不同菌盖形状毒蘑菇占比
cap.shape posion_number sum_number danger
b 34 368 0.0923913
c 2 2 1.0000000
f 1286 2559 0.5025401
k 484 670 0.7223881
s 0 28 0.0000000
x 1351 2873 0.4702402

由表5可知,当cap.shape=b,即菌盖形状为钟形时,一共有368种蘑菇,其中34种蘑菇有毒,因此当一种蘑菇的菌盖形状为钟形时,其有毒的概率为9.2%;当cap.shape=c,即菌盖形状为圆锥形时,一共有2种蘑菇,其中2种蘑菇都有毒,此时其有毒的概率为100%。因此,笔者认为这两个变量水平都与毒性有很高的相关性。将变量cap.shape的不同取值水平中毒蘑菇占比绘制成柱状图,如下所示:

barplot(z$danger,names.arg=z$cap.shape,main="cap.shape")
abline(h=0.75,col="red",lty=2,lwd=2)
abline(h=0.25,col="blue",lty=2,lwd=2)

由图可知,当蘑菇的菌盖形状为钟形(b)和漏斗形(s)时,该种蘑菇有超过75%的概率可食用;当其菌盖形状为圆锥形(c)时,有超过75%的概率有毒;而当其菌盖形状为平坦形(f)、球形(k)、凸形(x)时,不易判断该蘑菇是否有毒。由于变量cap.shape有一半的取值水平与毒性的相关性较低,它不是一个很好的分类特征。

#剩下的变量
posion_cor=function(data,column){
  column=as.character(column)
  x1=data.frame(class=data$class,column=column[-1]) %>%
    group_by(column) %>%
    summarise(danger=mean(class=="p"))
  p=barplot(x1$danger,names.arg=x1$column,main=column[1])
  p=p+abline(h=0.75,col="red",lty=2,lwd=2)
  return(p+abline(h=0.25,col="blue",lty=2,lwd=2))
}
#Error in plot.new() : figure margins too large,因此将图片保存后插入
par(mfrow=c(5,4))
train_data1[-(1:2)] %>%
  map(~posion_cor(train_data,.))

剩下19个变量各取值水平中毒蘑菇占比如图1所示。

变量各取值水平中毒蘑菇占比

变量各取值水平中毒蘑菇占比

可以看到,odor、gill.color、stalk.color.above.ring、stalk.color.below.ring和spore.print.color等变量与毒性的相关性较高。因此下面重点分析这几个变量。

与毒性相关性较高的变量分析

气味

train_data %>%
  count(odor,class) %>%
  ggplot()+
  geom_tile(aes(x=odor,y=class,fill=n))

可以看出,通过气味(odor)可以很好地分辨一种蘑菇是否有毒。当蘑菇有酚油味(c)、臭味(f)、霉味(m)、刺鼻的气味(p)、辛辣味(s)和鱼腥味(y)等奇怪味道时,该蘑菇很可能有毒。而当蘑菇的气味是杏仁味(a)、八角味(l)和没有味道(n)时,该蘑菇很可能无毒。

菌褶颜色

ggplot(data=train_data)+
  geom_bar(aes(x=gill.color,fill=class),color="black",position="fill")+
  scale_fill_manual(values=c("pink","lightblue"))+
  theme(panel.grid.major=element_blank())

由图可知,通过菌褶颜色(gill.color)也能较好地分辨一种蘑菇是否有毒。当蘑菇的菌褶为浅黄色(b)、灰色(g)、巧克力色(h)和绿色(r)时,该蘑菇有毒的概率更大。而当蘑菇的菌褶颜色为红色(e)、黑色(k)、棕色(n)、橙色(o)、紫色(u)和白色(w),该蘑菇无毒的概率更大。

菌柄环上和环下颜色

grid.newpage()
pushViewport(viewport(layout=grid.layout(2,1)))
vplayout=function(x,y){
  viewport(layout.pos.row=x,layout.pos.col=y)
}
p1=ggplot(data=train_data)+
  geom_bar(aes(x=stalk.color.above.ring,fill=class),color="black",position="dodge")+
  scale_fill_manual(values=c("pink","lightblue"))+
  theme(panel.grid.major=element_blank())
p2=ggplot(data=train_data)+
  geom_bar(aes(x=stalk.color.below.ring,fill=class),color="black",position="dodge")+
  scale_fill_manual(values=c("pink","lightblue"))+
  theme(panel.grid.major=element_blank())
print(p1,vp=vplayout(1,1))
print(p2,vp=vplayout(2,1))

变量stalk.color.above.ring和变量stalk.color.below.ring各取值水平下毒蘑菇占比十分相似,笔者尝试将两个变量结合为一个变量再进行分析。

train_data3=unite(train_data,stalk.color,stalk.color.above.ring,stalk.color.below.ring)
ggplot(data=train_data3)+
  geom_bar(aes(x=stalk.color,fill=class),color="black",position="dodge")+
  scale_fill_manual(values=c("pink","lightblue"))+
  theme(panel.grid.major=element_blank())

通过对比可以看出,结合后的变量stalk.color明显比原来的单个变量有更好的有毒蘑菇和可食用蘑菇的比例。当蘑菇的菌柄环上和环下颜色为浅黄、浅黄和棕色、浅黄和粉红、肉桂色、棕色和浅黄、棕色、棕色和粉红、粉红和浅黄、粉红和棕色、粉红、粉红和白色、白色和粉红、白色和黄色、黄色时,该蘑菇更可能有毒。而当蘑菇的菌柄环上和环下颜色为红色、红色和白色、灰色、灰色和粉红、灰色和白色、橙色、粉红和灰色、白色和红色、白色和灰色、白色和棕色、白色时,该蘑菇更可能可食用。

孢子印颜色

mosaic(~class+spore.print.color,data=train_data,highlighting="class",highlighting_fill=c("pink","lightblue"),direction=c("v","h"))

通过孢子印颜色(spore.print.color)也能较好地分辨一种蘑菇是否有毒。当蘑菇的孢子印为白色(w)、绿色(r)和巧克力色(h)时,该蘑菇有毒的概率更大。而当蘑菇的孢子印颜色为黄色(y)、紫色(u)、橙色(o)、棕色(n)、黑色(k)和浅黄色(b)时,该蘑菇无毒的概率更大。

模型预测

模型选择

对训练集构建决策树模型,选择所有20个自变量进行建模,定义此模型为model_1;将变量stalk.color.above.ring和变量stalk.color.below.ring结合为一个变量stalk.color(菌柄颜色)后,再次进行建模,定义此模型为model_2。

ct=rpart.control(xval=10,maxdepth=4,cp=0.00001)
mod1=rpart(class~.,data=train_data,method="class",parms=list(split="information"),control=ct)
mod2=rpart(class~.,data=train_data3,method="class",parms=list(split="information"),control=ct)

混淆矩阵

由于我们更关注的是将毒蘑菇分类正确的概率,因此本文将有毒蘑菇定义为正样本。所以下表中的accuracy指标代表决策树模型所有分类正确的蘑菇占所有蘑菇的比重;precision指标代表在所有模型预测为有毒的蘑菇中,模型预测对的比重;recall指标代表在实际有毒的蘑菇中,模型预测对的比重;f1-score指标则是precision和recall的调和平均值,其取值范围为0-1,越接近1代表模型分类效果越好。

model_1的混淆矩阵的指标:

pred1=predict(mod1,newdata=train_data,type="class")
accuracy1=mean(train_data$class==pred1)
precision1=sum(pred1==train_data$class & pred1=="p")/sum(pred1=="p")
recall1=sum(pred1==train_data$class & pred1=="p")/sum(train_data$class=="p")
f1_score1=2/(1/precision1+1/recall1)
table11=data.frame(index=c("accuracy","precision","recall","f1-score"),value=c(accuracy1,precision1,recall1,f1_score1))
knitr::kable(table11)
index value
accuracy 0.9978462
precision 1.0000000
recall 0.9955654
f1-score 0.9977778

model_2的混淆矩阵的指标:

pred2=predict(mod2,newdata=train_data3,type="class")
accuracy2=mean(train_data3$class==pred2)
precision2=sum(pred2==train_data3$class & pred2=="p")/sum(pred2=="p")
recall2=sum(pred2==train_data3$class & pred2=="p")/sum(train_data3$class=="p")
f1_score2=2/(1/precision2+1/recall2)
table22=data.frame(index=c("accuracy","precision","recall","f1-score"),value=c(accuracy2,precision2,recall2,f1_score2))
knitr::kable(table22)
index value
accuracy 0.9989231
precision 1.0000000
recall 0.9977827
f1-score 0.9988901

可以看到,模型model_1和model_2的precision指标均为1,即两个模型预测为有毒的蘑菇,实际也都有毒;而在其他指标上,model_2均比model_1表现更好,这主要是因为model_2的recall指标更高,即model_2将model_1错分为可食用蘑菇中的部分毒蘑菇分到了正确的类别中。综上所述,添加stalk.color变量后的模型model_2使得分类结果更加准确。

10折交叉验证

ct=rpart.control(xval=10,maxdepth=4,cp=0.00001)
folds=createFolds(y=train_data$class,k=10)
re_1={}
for(i in 1:10){
  trainda_1=train_data[-folds[[i]],]
  testda_1=train_data[folds[[i]],]
  mod_1=rpart(class~.,data=trainda_1,method="class",parms=list(split="information"),control=ct)
  pred=predict(mod_1,newdata=testda_1,type="class")
  re1=length(which(testda_1$class==pred))/length(testda_1$class)
  re_1=c(re_1,re1)
}
re_1=c(re_1,mean(re_1))
re_2={}
for(i in 1:10){
  trainda_2=train_data3[-folds[[i]],]
  testda_2=train_data3[folds[[i]],]
  mod_2=rpart(class~.,data=trainda_2,method="class",parms=list(split="information"),control=ct)
  pred_2=predict(mod_2,newdata=testda_2,type="class")
  re2=length(which(testda_2$class==pred_2))/length(testda_2$class)
  re_2=c(re_2,re2)
}
re_2=c(re_2,mean(re_2))
re_table=data.frame(fold=c(1:10,"mean"),model_1=re_1,model_2=re_2)
knitr::kable(re_table,caption="10折交叉验证结果")
10折交叉验证结果
fold model_1 model_2
1 0.9923195 0.9969278
2 0.9953775 0.9984592
3 0.9953846 1.0000000
4 1.0000000 1.0000000
5 0.9984592 0.9984592
6 1.0000000 1.0000000
7 0.9969231 0.9984615
8 0.9969231 0.9984615
9 0.9984615 1.0000000
10 0.9969278 0.9984639
mean 0.9970776 0.9989233

两个模型的10折交叉验证结果(表8)显示,使用model_2模型对蘑菇分类的平均准确率更高。

综合10折交叉验证结果和混淆矩阵的指标可知,选择model_2模型作为蘑菇分类的预测模型效果更好。

模型重要性检验

检测构建的决策树模型中各变量的重要性得分,结果如下:

importance=data.frame(variable.importance=mod2$variable.importance)
importance=data.frame(variable=row.names(importance),importance,row.names=NULL)
importance$variable=reorder(importance$variable,importance$variable.importance,FUN=sum)
ggplot(importance)+
  geom_bar(aes(x=variable,y=variable.importance,fill=variable),stat='identity',show.legend=FALSE)+
  theme(axis.title.y=element_blank(),panel.grid.major=element_blank())+
  coord_flip()

由上图可以看出,在本文构建的决策树模型中,odor、spore.print.color和gill.color等6个变量是最重要的,其中,创建的新变量stalk.color的重要性得分排名比较靠前,说明创建的这个新变量是十分有意义的。

模型预测结果

用建立的决策树模型model_2预测测试集中各种蘑菇是否有毒,输出混淆矩阵指标如下:

test_data1=unite(test_data,stalk.color,stalk.color.above.ring,stalk.color.below.ring)
pred=predict(mod2,newdata=test_data1,type="class")
accuracy=mean(test_data1$class==pred)
precision=sum(pred==test_data1$class & pred=="p")/sum(pred=="p")
recall=sum(pred==test_data1$class & pred=="p")/sum(test_data1$class=="p")
f1_score=2/(1/precision+1/recall)
table33=data.frame(index=c("accuracy","precision","recall","f1-score"),value=c(accuracy,precision,recall,f1_score))
knitr::kable(table33)
index value
accuracy 0.9993842
precision 1.0000000
recall 0.9986825
f1-score 0.9993408

可以看到,测试集中99.9%的蘑菇都分类正确,所有模型预测为有毒的蘑菇,实际上也都有毒。但是在实际有毒的蘑菇中,仍有0.2%被模型分类为可食用。总体来说,模型的预测效果非常好。

结论与不足

结论

rpart.plot(mod2,branch=1,type=1,extra=1,main="model_2 classification results")

根据上文的分析和上图可知,针对流传甚广的两大毒蘑菇分辨方法:“鲜艳的蘑菇都是有毒的,无毒蘑菇颜色朴素”和“可食用的无毒蘑菇多生长在清洁的草地或松树、栎树上,有毒蘑菇往往生长在阴暗、潮湿的肮脏地带”中所出现的颜色和环境特征。菌盖颜色未入选分类的重要变量;而入选的孢子印颜色中,孢子印为鲜艳的橙色和紫色蘑菇被分为无毒的概率更大;对于菌柄颜色而言,菌柄为鲜艳的红色和橙色的蘑菇也有更大的概率被认为无毒。至于环境特征,也未入选分类的重要变量。可以证明,这两种毒蘑菇分辨方法都是没有事实依据的。

而有效的分辨方法是综合利用蘑菇的气味、孢子印颜色、菌柄颜色和菌柄上环表面这四个特征来判断蘑菇是否有毒。但是也要认识到,此种分类方法仍然有很小的概率将毒蘑菇误分入可食用蘑菇中。实际生活中仍要慎用此方法。

不足与改进方向

虽然本文创建的新特征stalk.color在模型的重要性检测中得分较高,且对改善模型有一定效果,但是由于时间有限,还有很多特征有待挖掘。所以笔者接下来的一个改进方向是增加特征组合,产生一个特征候选集,通过处理筛选掉无用的特征后,增强模型的预测能力。