1 Introduction 

This R Markdown is my code in kaggle Competition Binary Prediction of Poisonous Mushrooms.

The source code in kaggle is : mushroom-rmarkdown-mlr3

The goal of this competition is to predict whether a mushroom is edible or poisonous based on its physical characteristics.

Evaluation(MMC)

\[ MCC=\frac{\left(TP\times TN-FP\times FN\right)}{\sqrt{\left(TP+FP\right)\left(TP+FN\right)\left(TN+FP\right)\left(TN+FN\right)}} \]

2 Setup

  1. Set the library path for loading the latest versions of catboost, lightgbm, and xgboost, where catboost and xgboost support GPU mode;

    • catboost:1.2.5(support GPU)
    • lightgbm:4.5.0
    • xgboost: 2.0.1.1(support GPU)
    • mlr3extralearners: 0.8.0
isKaggle = FALSE
set.seed(125432)
train_data_url=  "e:/data/kaggle/playground-series-s4e8/train.csv"
test_data_url = "e:/data/kaggle/playground-series-s4e8/test.csv"
if(isKaggle){
  .libPaths("/kaggle/input/r-extlib/kaggle/working/r-extlib")
  train_data_url= "/kaggle/input/playground-series-s4e8/train.csv"
  test_data_url = "/kaggle/input/playground-series-s4e8/test.csv"
}
library(dlookr)
library(mlr3extralearners)
library(tidyverse)
library(mlr3verse)

theme_set(cowplot::theme_minimal_grid())

3 Load Data

Load the training dataset and testing dataset, and rename each column of the dataset according to the R language specification.

train_data <- read_csv(train_data_url) 
test_data<- read_csv(test_data_url)
colnames(train_data) = gsub("-", "_", colnames(train_data))
colnames(test_data) = gsub("-", "_", colnames(test_data))

Conclusion

In the training and testing sets:

  • On numerical and logical variables, the range of data values, distribution of values, and missing values are relatively consistent

  • On character variables, the situation of missing values is relatively consistent, but there are differences in the range of values for some fields

  • Physical characteristics can be divided into: cap_*、gill_*、stem_*、veil_*,and others

3.1 Train Dataset

  • Data summary
skimr::skim(train_data)
Data summary
Name train_data
Number of rows 3116945
Number of columns 22
_______________________
Column type frequency:
character 16
logical 2
numeric 4
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
class 0 1.00 1 1 0 2 0
cap_shape 40 1.00 1 9 0 74 0
cap_surface 671023 0.78 1 20 0 83 0
cap_color 12 1.00 1 20 0 78 0
gill_attachment 523936 0.83 1 20 0 78 0
gill_spacing 1258435 0.60 1 11 0 48 0
gill_color 57 1.00 1 20 0 63 0
stem_root 2757023 0.12 1 17 0 38 0
stem_surface 1980861 0.36 1 20 0 60 0
stem_color 38 1.00 1 17 0 59 0
veil_type 2957493 0.05 1 7 0 22 0
veil_color 2740947 0.12 1 4 0 24 0
ring_type 128880 0.96 1 20 0 40 0
spore_print_color 2849682 0.09 1 10 0 32 0
habitat 45 1.00 1 20 0 52 0
season 0 1.00 1 1 0 4 0

Variable type: logical

skim_variable n_missing complete_rate mean count
does_bruise_or_bleed 117 1 0.18 FAL: 2569743, TRU: 547085
has_ring 143 1 0.24 FAL: 2368820, TRU: 747982

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
id 0 1 1558472.00 899784.66 0.00 779236.00 1558472.00 2337708.00 3116944.00 ▇▇▇▇▇
cap_diameter 4 1 6.31 4.66 0.03 3.32 5.75 8.24 80.67 ▇▁▁▁▁
stem_height 0 1 6.35 2.70 0.00 4.67 5.88 7.41 88.72 ▇▁▁▁▁
stem_width 0 1 11.15 8.10 0.00 4.97 9.65 15.63 102.90 ▇▁▁▁▁
  • Samples
head(train_data)

3.2 Test Dataset

  • Data summary
skimr::skim(test_data)
Data summary
Name test_data
Number of rows 2077964
Number of columns 21
_______________________
Column type frequency:
character 15
logical 2
numeric 4
________________________
Group variables None

Variable type: character

skim_variable n_missing complete_rate min max empty n_unique whitespace
cap_shape 31 1.00 1 12 0 62 0
cap_surface 446904 0.78 1 17 0 59 0
cap_color 13 1.00 1 88 0 57 0
gill_attachment 349821 0.83 1 17 0 66 0
gill_spacing 839595 0.60 1 9 0 35 0
gill_color 49 1.00 1 20 0 56 0
stem_root 1838012 0.12 1 5 0 31 0
stem_surface 1321488 0.36 1 20 0 54 0
stem_color 21 1.00 1 20 0 55 0
veil_type 1971545 0.05 1 2 0 15 0
veil_color 1826124 0.12 1 4 0 23 0
ring_type 86195 0.96 1 17 0 36 0
spore_print_color 1899617 0.09 1 10 0 33 0
habitat 25 1.00 1 17 0 39 0
season 0 1.00 1 1 0 4 0

Variable type: logical

skim_variable n_missing complete_rate mean count
does_bruise_or_bleed 75 1 0.18 FAL: 1713662, TRU: 364227
has_ring 113 1 0.24 FAL: 1578092, TRU: 499759

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
id 0 1 4155926.50 599856.68 3116945 3636435.75 4155926.50 4675417.25 5194908.00 ▇▇▇▇▇
cap_diameter 7 1 6.31 4.69 0 3.31 5.74 8.23 607.00 ▇▁▁▁▁
stem_height 1 1 6.35 2.70 0 4.67 5.88 7.41 57.29 ▇▁▁▁▁
stem_width 0 1 11.15 8.10 0 4.97 9.64 15.62 102.91 ▇▁▁▁▁
  • Samples
head(test_data)

4 EDA

4.1 Character Variable

View the values and distribution of character variables(Categorical values) with a proportion greater than 0.1% in the training and testing sets, as well as the differences between the two datasets

train_category = diagnose_category(train_data%>% select(-class)) %>% filter(ratio>0.001)
test_category = diagnose_category(test_data) %>% filter(ratio>0.001)

category_infos = train_category %>%
    left_join(test_category,by=c('variables','levels'),suffix = c(".train", ".test")) %>%
    mutate(ratio_diff = abs(ratio.train-ratio.test)) %>% 
    select(-N.train,-N.test)
category_infos

It can be seen that the values and distributions of the training and testing sets are basically consistent when the character type variable is greater than 0.1%.

  • Distribution Plot
library(patchwork)

gobal_tpr =  table(train_data$class)[2]/nrow(train_data)

plot_norminal_feature = function(field){
  # a statistical proportion greater than 0.1%
  min_ratio = 0.1
  df = train_data %>% 
    select(field,class) %>% 
    group_by_at(field) %>% 
    summarise(
      total_count = n()
    ) %>%  
    ungroup()  
  
   class_df = train_data %>% 
    select(field,class) %>% 
    group_by_at(c('class',field)) %>% 
    summarise(
      count = n()
    )
   
   join_df = df %>% left_join(class_df,by = field) %>% 
     mutate(
       class_ratio = count*100 / total_count,
       global_ratio = count *100 / nrow(train_data),
       level_ratio = total_count *100 / nrow(train_data)) %>% 
     rename(level = field) %>% 
     filter(level_ratio> min_ratio)
   
   
   
   top_level = length(unique(join_df$level))
   
   p1 = join_df %>% ggplot(aes(x=reorder(level,total_count),y = global_ratio,fill=class))+
     geom_col()+
     geom_text(aes(label = ifelse(class_ratio>70,round(class_ratio),'')),    hjust = -2,size =3 )+
     coord_flip() +
     xlab(paste0( "top ",top_level," level ( > 0.1% )"))+
     ylab(field)+
     theme(legend.position = "none")+
     geom_hline(yintercept = 0.1,color = 'blue', linetype = 'dashed') 
   
    p2 = join_df %>% ggplot()+
     geom_col(aes(x=reorder(level,-total_count),y = class_ratio,fill=class),position = 'fill')+
     geom_hline(yintercept = gobal_tpr,color = 'black', linetype = 'dashed' ) +
     # geom_hline(yintercept = 0.7,color = 'red', linetype = 'dashed',alpha =0.5) +
     # geom_hline(yintercept = 0.4,color = 'blue', linetype = 'dashed',alpha =0.5) +
     xlab(paste0(field, " - top ",top_level," level ")) 
    
    (p1+ p2  )
 
}

4.1.1 Cap

plot_norminal_feature ('cap_color')+
  plot_norminal_feature ('cap_shape') +
  plot_norminal_feature ('cap_surface')

attr top level(>0.1%) level count p >70% e >70%
cap_color n,y,w,g,e,o,p,r,u,b,k,i 12 e,o,r b
cap_shape x,f,s,b,o,p,c 7 b
cap_surface NA,t,s,y,h,g,d,k,e,i,w,l 12 k,i

4.1.2 Gill

plot_norminal_feature ('gill_attachment')+
  plot_norminal_feature ('gill_spacing')+
  plot_norminal_feature ('gill_color')

attr top level(>0.1%) level count p >70% e >70%
gill_attachment NA,a,d,x,e,s,p,f 8 p
gill_spacing NA,c,d,f 4
gill_color w,n,y,p,g,o,k,f,r,e,b,u 12 n,r b

4.1.3 Stem

plot_norminal_feature ('stem_root')+
  plot_norminal_feature ('stem_surface')+
  plot_norminal_feature ('stem_color')

attr top level(>0.1%) level count p >70% e >70%
stem_root NA,b,s,r,c 5 r,c
stem_surface NA,s,y,i,t,g,k,h 8 y,g,h
stem_color w,n,y,g,o,e,u,p,k,r,l,b 12 e,p,k,r b

4.1.4 Veil

plot_norminal_feature ('veil_type')|
  plot_norminal_feature ('veil_color')

attr top level(>0.1%) level count p >70% e >70%
veil_type NA,u 2
veil_color NA,w,y,n,u,k,e 7 n,u,k,e y

4.1.5 Other

(plot_norminal_feature ('ring_type')|plot_norminal_feature ('spore_print_color'))/ 
(plot_norminal_feature ('habitat')|plot_norminal_feature ('season'))

attr top level(>0.1%) level count p >70% e >70%
ring_type NA,f,e,z,l,r,p,g,m 9 z
spore_print_color NA,k,p,w,n,r,u,g 8 k,p,n,r,u g
habitat d,g,l,m,h,w,p,u 8 p w,u
season a,u,w,s 4

4.2 Numeric Variable

Find the numerical variable that contains outliers in train dataset

train_data  %>% select(find_outliers(.)) %>% plot_outlier()

finding intervals for numerical variable using recursive information gain ratio maximization

train_data %>% select(class,cap_diameter) %>% na.omit() %>%
  binning_rgr(class,cap_diameter,min_perc_bins=0.01,max_n_bins = 8)

train_data %>% select(class,stem_height) %>% na.omit() %>%
  binning_rgr(class,stem_height,min_perc_bins=0.01,max_n_bins = 8)

train_data %>% select(class,stem_width) %>% na.omit() %>%
  binning_rgr(class,stem_width, max_n_bins = 4)
# bin is:
# cap_diameter_bin = c(2.99,4.11,5.62,8.81,11.48,12.23,15.42) 
# stem_height_bin = c(4.4,8.37,9.73,10.48,11.12,11.78,13.06)
# stem_width_bin = c(0, 7.1, 9.75, 13.47, 19.41, 31.99, 103)
# plot numeric var
plot_num= function(data,var){
  bin = c(0, 17, 40, 60) 
  
  base_p = data %>% 
  select(class,{{var}}) %>% 
  ggplot(aes(x = {{var}},fill = class))

  dou_p = base_p + geom_boxplot()
  
  p0 = base_p + 
    geom_density(alpha = 0.5)
  
  p_all = dou_p / p0
  
  for(index in 1:(length(bin)-1)){
    min = bin[index]
    max = bin[index+1]

    p1 = base_p +
    geom_density(alpha = 0.5) +
    scale_x_continuous(limits = c(min,max), breaks = seq(0,100,2),labels = seq(0,100,2))+
    theme(legend.position = "none")

    p2= base_p +
      geom_histogram(position = "fill", color='black',bins = 20)+
      scale_x_continuous(limits = c(min,max), breaks = seq(0,100,2),labels = seq(0,100,2)) +
       geom_hline(yintercept = gobal_tpr,color = 'black', linetype = 'dashed')

    p_all = p_all / (p1|p2)
  }
  p_all
}

4.2.1 cap_diameter

train_data %>% plot_num(cap_diameter)

column rang p e p/e die or happy
cap_diameter 0~1 67497 19574 3.45 💀💀💀
1~2 186542 89967 2.07 💀💀
2~3 200687 63299 3.17 💀💀💀
3~5 434059 304029 1.43 💀
5~17 789721 911028 0.87 😄
17~40 26650 10974 2.43 💀💀
>40 237 12676 0.19 😄😄😄

4.2.2 stem_width

train_data %>% plot_num(stem_width)

column rang p e p/e die or happy
stem_width 0~8 925284 439700 2.1 💀💀
8~17 529829 581377 0.91 😄
17~27 148067 349131 0.42 😄😄
27~50 101386 39182 2.59 💀💀💀
50~60 783 1219 0.64 😄
>60 46 931 0.05 😄😄😄😄😄……….😄

4.2.3 stem_height

train_data %>% plot_num(stem_height)

column rang p e p/e die or happy
stem_height 0~2 24627 119 207.95 💀💀💀💀💀……….💀
2~2.5 10718 10833 0.99 😄
2.5~4 313967 114937 2.73 💀💀💀
4~9 1109790 1135538 0.98 😄
9~18 242477 137011 1.77 💀💀
18~40 3813 13081 0.29 😄😄😄
>40 4 30 0.13 😄😄😄

4.3 Correlations

library(corrplot)
train_data %>% 
  mutate_if(is.character, ~ replace_na(., "UNKN")) %>% 
  na.omit() %>% 
  mutate_if(is.character,as.factor) %>% 
  mutate_if(is.logical,as.factor) %>% 
  mutate_if(is.factor,as.numeric) %>% 
  cor()%>% corrplot() 

  • stem_width 、cap_diameter:+++

  • ring_type、hash_ring:++

  • stem_height、cap_diameter:+

  • stem_height、veil_type:-

  • hash_ring、veil_type -

p1 = train_data %>% 
  select(stem_width,cap_diameter,class) %>% 
  sample_frac(size = 0.001) %>%
  ggplot(aes(x = stem_width,y = cap_diameter,color=class)) +
  geom_point() + 
  geom_smooth(method = 'gam',span = 0.3)+
  lims(x=c(0,40),y=c(0,40))+
  labs(title = 'stem_width - cap_diameter corrplot')+
  theme(legend.position = 'none')

p2 = train_data %>% 
  select(stem_height,cap_diameter,class) %>% 
  sample_frac(size = 0.001) %>%
  ggplot(aes(x = stem_height,y = cap_diameter,color=class)) +
  geom_point() + 
  geom_smooth(method = 'gam',span = 0.3)+
  lims(x=c(0,40),y=c(0,40))+
  labs(title = 'stem_height - cap_diameter corrplot')

p1|p2

5 Data cleaning

5.1 outltier

  • Character variable: The value range should be a single letter between a and z. Values outside of this range are marked as outliers, and a new column is_outlier_char is added to identify them
  • Numeric variable: The value range should be>0. Values outside of this range are marked as outliers, and a new column is_outlier_num is added to identify them
  • Logical variable: No outlier situation, no need to handle

5.1.1 mark outltier

mark_outlier =function (data){
    # all character variables in the dataset
    # If the value  length in the variable is greater than 1, it will be marked as an outlier
    outlier_char_ids = (data %>% filter(if_any(where(is.character),  ~nchar(.x) > 1)))$id
    outlier_num_ids = (data %>% filter(if_any(where(is.numeric),  ~.x <= 0)))$id
    data %>% mutate(
        is_outlier_char = if_else(id %in% outlier_char_ids,  1, 0), 
        is_outlier_num = if_else(id %in% outlier_num_ids,1, 0))
}

mark train dataset and test dataset

train_data = train_data %>%  mark_outlier()
test_data = test_data %>%  mark_outlier()
train_data %>% filter(is_outlier_char == 1) 
test_data %>% filter(is_outlier_char == 1) 

5.1.2 fill outliter

find_outlier_category =function(data,show = F){
    # 1.find outliter
    # 2.calculate the values and frequencies of each character variable
    # 3.filter nchar(levels)>1 
    # 4.sum and order
    outlier_data = data %>%  filter(is_outlier_char==1)
    outlier_category =
      outlier_data %>%
      diagnose_category(top = 1000) %>% 
      filter(nchar(levels)>1) %>% 
      group_by(levels) %>% 
      summarise(total_freq = sum(freq)) %>% 
      arrange(desc(levels))
    if(show){
        # print outliter  
        print(outlier_category,n=50) 
    }
    outlier_category
}

before fill outliter

toString(find_outlier_category(train_data )$levels)
## [1] "veil-type, veil-color, stem-surface, stem-root, spore-print-color, spore-color, spacing, sp, season, ring-type, p p, is y, is w, is s, is p, is n, is k, is h, is f, is a, is None, has-ring, has h, has f, has d, habitat, e y, e n, does-bruise-or-bleed, does w, does t, does s, does n, does l, does h, does f, does None, does, class, cap-surface, cap-diameter, b f, 9.88, 9.55, 9.46, 9.28, 9.22, 9.13, 9.02, 9.01, 9 None, 8.96, 8.83, 8.79, 8.67, 8.57, 8.49, 8.47, 8.37, 8.32, 8.3, 8.29, 8.25, 8.1, 8.09, 8.06, 8.01, 7.99, 7.92, 7.86, 7.84, 7.72, 7.7, 7.6, 7.59, 7.45, 7.43, 7.41, 7.37, 7.33, 7.31, 7.23, 7.21, 7.15, 7.14, 7.09, 7 x, 6.9, 6.76, 6.75, 6.74, 6.67, 6.63, 6.59, 6.58, 6.57, 6.53, 6.49, 6.45, 6.44, 6.41, 6.4, 6.36, 6.32, 6.31, 6.21, 6.2, 6.19, 6.11, 6.09, 6 x, 55.13, 54.78, 50.44, 5.97, 5.94, 5.93, 5.91, 5.81, 5.73, 5.7, 5.62, 5.59, 5.56, 5.55, 5.51, 5.48, 5.42, 5.41, 5.35, 5.25, 5.22, 5.15, 5.07, 5.01, 5 f, 49.46, 49.21, 41.91, 4.98, 4.97, 4.93, 4.89, 4.8, 4.77, 4.75, 4.74, 4.66, 4.64, 4.62, 4.58, 4.49, 4.41, 4.34, 4.33, 4.3, 4.24, 4.22, 4.21, 4.09, 4.04, 4.01, 4. n, 39.51, 33.52, 32.54, 3.98, 3.95, 3.92, 3.91, 3.89, 3.81, 3.71, 3.68, 3.64, 3.63, 3.62, 3.61, 3.6, 3.57, 3.56, 3.55, 3.53, 3.52, 3.49, 3.45, 3.4, 3.39, 3.37, 3.34, 3.33, 3.32, 3.25, 3.24, 3.23, 3.13, 3.12, 3.11, 3.08, 3.06, 3.04, 3 x, 28.7, 28.15, 26.89, 26.4, 25.98, 25.83, 24.75, 24.38, 24.16, 24.12, 23.6, 23.59, 23.18, 22.6, 22.38, 21.56, 21.53, 21.11, 20.62, 20.6, 20.44, 20.25, 20.07, 20.02, 20.01, 20.0, 20, 2.94, 2.92, 2.9, 2.87, 2.85, 2.82, 2.81, 2.79, 2.78, 2.77, 2.75, 2.7, 2.69, 2.68, 2.67, 2.63, 2.62, 2.57, 2.54, 2.51, 2.49, 2.44, 2.41, 2.25, 2.11, 2.05, 19.65, 19.35, 19.29, 19.06, 19.04, 18.35, 18.21, 18.12, 18.06, 18.03, 17.94, 17.93, 17.46, 17.45, 17.44, 17.38, 17.19, 17.1, 17, 16.88, 16.46, 16.41, 16.39, 16.33, 16.27, 15.94, 15.69, 15.49, 15, 14.04, 14, 13.94, 13.15, 13.1, 13.03, 13, 12.92, 12.89, 12.79, 12.62, 12.27, 12.2, 12.04, 11.92, 11.78, 11.62, 11.26, 11.13, 11.12, 11, 10.93, 10.87, 10.85, 10.83, 10.56, 10.48, 10.46, 10.34, 10.23, 10.21, 10.13, 10.1, 10.07, 10 None, 1.94, 1.91, 1.88, 1.75, 1.66, 1.6, 1.59, 1.51, 1.48, 1.43, 1.42, 1.41, 1.37, 1.36, 1.32, 1.14, 1.08, 1.03, 0.92, 0.88, 0.87, 0.85, 0.82, 0.73, 0.0"
toString(find_outlier_category(test_data )$levels)
## [1] "veil-type, veil-color, stem-root, spore-print-color, spacing, sp, season, ring-type, p f, is w, is p, is f, is None, has-ring, has g, has f, habitat, e s, does-bruise-or-bleed, does f, does c, does None, class, cap-diameter, cap---------------------------------------------------------------------------------root, 9.98, 9.69, 9.53, 9.49, 9.41, 9.33, 9.19, 9.01, 8.95, 8.82, 8.73, 8.53, 8.34, 8.33, 8.3, 8.21, 8.12, 8.1, 8.09, 8.04, 7.96, 7.81, 7.78, 7.71, 7.41, 7.35, 7.21, 7.18, 7.07, 7.01, 6.99, 6.93, 6.75, 6.74, 6.7, 6.67, 6.59, 6.58, 6.52, 6.51, 6.47, 6.35, 6.3, 6.28, 6.24, 6.18, 6.15, 6.14, 6.12, 6.11, 6.06, 6.04, 51.63, 5.98, 5.93, 5.92, 5.84, 5.83, 5.61, 5.59, 5.57, 5.5, 5.38, 5.35, 5.16, 5.1, 5.05, 5.01, 5 f, 4.96, 4.92, 4.91, 4.87, 4.86, 4.8, 4.78, 4.76, 4.75, 4.61, 4.58, 4.55, 4.54, 4.51, 4.5, 4.41, 4.18, 4.11, 4.02, 35.38, 32.63, 32.6, 3.94, 3.8, 3.73, 3.7, 3.66, 3.65, 3.59, 3.49, 3.48, 3.47, 3.42, 3.38, 3.34, 3.32, 3.24, 3.23, 3.19, 3.15, 3.05, 3.02, 3.0, 29.82, 27.48, 26.48, 25.92, 24.74, 24.73, 23.96, 23.73, 23.18, 22.33, 21.87, 21.38, 20.64, 2.98, 2.97, 2.96, 2.95, 2.93, 2.92, 2.86, 2.84, 2.83, 2.82, 2.75, 2.73, 2.7, 2.6, 2.53, 2.52, 2.51, 2.47, 2.44, 2.17, 2.02, 19.85, 19.76, 19.46, 19.18, 18.89, 18.5, 18.49, 18.35, 18.29, 18.05, 18, 17.98, 17.97, 17.89, 17.77, 17.72, 17.49, 17.26, 17.16, 17.11, 17.01, 16.48, 15.55, 15.52, 15.51, 15.25, 14.18, 13.66, 13.46, 13.42, 13.09, 13.01, 12.99, 12.91, 12.87, 12.63, 12.3, 12.22, 12.15, 11.96, 11.8, 11.53, 11.43, 11.31, 11.0, 11, 10.93, 10.83, 10.62, 10.6, 10.56, 10.36, 10.34, 10.14, 10.09, 1.95, 1.83, 1.75, 1.68, 1.64, 1.62, 1.61, 1.58, 1.56, 1.55, 1.53, 1.48, 1.46, 1.32, 1.26, 1.25, 0.97, 0.95, 0.94, 0.93, 0.91, 0.88, 0.87, 0.74, 0.73"
fill_outlier_datas = function(data){
   outlier_datas = data %>% filter(is_outlier_char == 1)
   outlier_data_clean = outlier_datas %>% 
      mutate_if(is.character,
      # 异常点替换
      ~ case_when(
        .x %in% c('veil-type','veil-color') ~'v',
        .x %in% c('stem-root','spore-print-color','spacing','sp','season','is s','does s','stem-surface','spore-color') ~'s',
        .x %in% c('ring-type') ~'r',
        .x %in% c('is f','does f','has f','5 f') ~'f',
        .x %in% c('is None','does None','10 None','9 None') ~ NA,
        .x %in% c('p f','is p','p p') ~ 'p',
        .x %in% c('is w','does w') ~ 'w',
        .x %in% c('is t','does t') ~ 't',
        .x %in% c('is l','does l') ~ 'l',
        .x %in% c('is y') ~ 'y',
        .x %in% c('7 x','6 x','5 x','4 x','3 x','2 x','1 x') ~ 'x',
        .x %in% c('is n','does n','4. n') ~ 'n',
        .x %in% c('is k') ~ 'k',
        .x %in% c('is a') ~ 'a',
        .x %in% c('b f') ~ 'b',
        .x %in% c('has-ring','habitat','is h','has h','does h') ~ 'h',
        .x %in% c('has g') ~ 'g',
        .x %in% c('e s','e y','e n') ~ 'e',
        .x %in% c('does-bruise-or-bleed','has d','does') ~ 'd',
        .x %in% c('does c','cap-surface','cap-diameter','cap---------------------------------------------------------------------------------root') ~ 'c',
        .x %in% c('class') ~ 'p',
         # 对于大量的字符类型填充了数值类型的数据,将标记为采用当前变量的众数替换
         grepl("^[+-]?\\d+(?:\\.\\d+)?$", .x) ~  'mode',
        .default = .x
        )
      ) %>% 
      mutate(
        # 用变量的众数替换标记为mode的数据
        cap_color = if_else(cap_color == 'mode','n',cap_color),
        cap_shape = if_else(cap_shape == 'mode','x',cap_shape),
        cap_surface = if_else(cap_surface == 'mode','t',cap_surface),
        gill_attachment = if_else(gill_attachment == 'mode','a',gill_attachment),
        gill_spacing = if_else(gill_spacing == 'mode','c',gill_spacing),
        gill_color = if_else(gill_color == 'mode','w',gill_color),
        stem_root = if_else(stem_root == 'mode','b',stem_root),
        stem_surface = if_else(stem_surface == 'mode','s',stem_surface),
        stem_color = if_else(stem_color == 'mode','w',stem_color),
        veil_type = if_else(veil_type == 'mode','u',veil_type),
        veil_color = if_else(veil_color == 'mode','w',veil_color),
        ring_type = if_else(ring_type == 'mode','f',ring_type),
        spore_print_color = if_else(spore_print_color == 'mode','k',spore_print_color),
        habitat = if_else(habitat == 'mode','d',habitat)
      )
 # 在原始数据集中过滤掉异常值数据,然后正常数据集和处理后的异常数据重新合并,并按照id重新排序
 data %>% 
    filter(! id %in% outlier_data_clean$id ) %>% 
    bind_rows(outlier_data_clean) %>% 
    arrange(id)
}
train_data = fill_outlier_datas(train_data) 
test_data = fill_outlier_datas(test_data) 

after fill outliter

find_outlier_category(train_data )$levels
find_outlier_category(test_data )$levels

5.2 fill na

  • 👉 Character variable: filled with ‘UNKN’
  • 👉 Numeric variable:filled with 0
  • 👉 Logical variable:filled with FALSE
fill_na = function (data){
    data %>%
        mutate(across(where(is.character),~ if_else(is.na(.x),'UNKN',.x)))%>%
        mutate(across(where(is.numeric),~ if_else(is.na(.x),0,.x)))%>%
        mutate(across(where(is.logical),~ if_else(is.na(.x),FALSE,.x))) 
}
train_data = train_data %>%  fill_na()  
test_data = test_data %>%  fill_na()

6 Modeling

6.1 pipeline

# character to factor
poca = po("colapply", applicator = as.factor) 
poca$param_set$values$affect_columns = selector_union(selector_type("character"),selector_type('logical'))

# collapse factors
po_collapsefactors = po("collapsefactors", no_collapse_above_prevalence = 0.001)

# encode
po_encode = 
po("encode", method = "one-hot",
   affect_columns = selector_cardinality_greater_than(5),
   id = "low_card_enc")

# convert factor directly to a numerical value
po_factor_to_num = po("colapply", applicator = as.numeric,id="factor_to_num",affect_columns = selector_type("factor"))


po_scale = po("scale", affect_columns =  selector_name(c("stem_width","stem_height","cap_diameter")))

po_log10 = po("mutate", affect_columns =  selector_name(c("stem_width","stem_height","cap_diameter")))
po_log10$param_set$values$mutation = list(
  stem_width = ~ log10(stem_width + 1),
  stem_height = ~ log10(stem_height + 1),
  cap_diameter = ~ log10(cap_diameter + 1)
)

po_predata =
    poca %>>% po("fixfactors") %>>% po_collapsefactors %>>%
    po_encode  %>>% po_factor_to_num %>>%
    po_log10 %>>% po_scale 

6.2 task

task = as_task_classif(train_data %>% select(-id) ,target = 'class', positive = 'p',id='mushrooms')
task$col_roles$stratum = task$target_names

6.3 lightgbm

lrn_lightgbm =lrn("classif.lightgbm",predict_type = "prob"
                  ,num_iterations = 2500
                  ,objective= 'binary'
                  ,max_bin= 1024
                  ,learning_rate= 0.08
                  ,lambda_l1= 4.2
                  ,lambda_l2=5e-05
                  ,device_type = 'cpu'
                  ,force_col_wise =T
                  ,num_threads =4)

lrn_lightgbm = as_learner(poca %>>% po("fixfactors") %>>% po_collapsefactors   %>>% lrn_lightgbm)
lrn_lightgbm$id="lrn_lightgbm"

benchmark(over 40 minutes😴)

lrn_lightgbm$predict_sets = c("train","test")
bmr = benchmark(
  benchmark_grid(
    tasks =  task,
    learners = list( lrn_lightgbm),
    resamplings = rsmp("cv",folds =4)
  )
)

bmr$score(msr("classif.mcc"))
bmr$aggregate(msr("classif.mcc",predict_sets='train'))
bmr$aggregate(msr("classif.mcc"))

6.4 catboost

lrn_catboost = lrn("classif.catboost"
                   ,predict_type = "prob" 
                   ,iterations = 1000
                   ,learning_rate = 0.1
                   ,logging_level  = 'Info'
                   ,metric_period = 100
                   ,grow_policy= 'Lossguide'
                   # ,task_type= 'GPU'
                   ,thread_count =4)

lrn_catboost = as_learner(po_predata   %>>% lrn_catboost)
lrn_catboost$id="lrn_catboost"

benchmark(over 20 minutes😞)

use GPU over 20 minutes, no GPU maybe over 200 minutes !!!

lrn_catboost$predict_sets = c("train","test")
bmr = benchmark(
  benchmark_grid(
    tasks =  task,
    learners = list( lrn_catboost),
    resamplings = rsmp("cv",folds =4)
  )
)

bmr$score(msr("classif.mcc"))
bmr$aggregate(msr("classif.mcc",predict_sets='train'))
bmr$aggregate(msr("classif.mcc"))

6.5 xgboost

lrn_xgboost = lrn("classif.xgboost"
                  ,predict_type = "prob"  
                  ,booster = 'gbtree'
                  ,tree_method = 'hist'
                  ,eta = 1
                  ,nrounds = 100
                  # ,device = 'cuda'
                  ,num_parallel_tree = 100
                  ,colsample_bynode= 0.7
                  ,subsample= 0.8
                  ,colsample_bytree=0.5
                  ,lambda= 8
                  ,max_depth= 8
                  ,nthread =4)
lrn_xgboost = as_learner(po_predata   %>>% lrn_xgboost)
lrn_xgboost$id="lrn_xgboost"

benchmark(over 30 minutes😞)

use GPU over 30 minutes, no GPU maybe over 300 minutes !!!

lrn_xgboost$predict_sets = c("train","test")
bmr = benchmark(
  benchmark_grid(
    tasks =  task,
    learners = list( lrn_xgboost),
    resamplings = rsmp("cv",folds =4)
  )
)

bmr$score(msr("classif.mcc"))
bmr$aggregate(msrs(c("classif.mcc","time_train"),predict_sets='train'))
bmr$aggregate(msr("classif.mcc"))
print("learner init finished ")
## [1] "learner init finished "

6.6 stacking

lrn_super = lrn("classif.glmnet")
lrn_stacking = as_learner(po_predata %>>% ppl("stacking",
                                   base_learners = list(lrn_lightgbm,lrn_catboost,lrn_xgboost),
                                   super_learner = lrn_super, 
                                   use_features = FALSE, 
                                   folds = 4))
 
lrn_stacking$id = "lrn_stacking"
print("lrn_stacking init finished ")
## [1] "lrn_stacking init finished "

7 Tune

7.1 subsample frac

  • 👉 Purpose :Sampling the training set data at different ratios to evaluate the performance differences under different conditions; I hope to find a compromise between high performance and short training time (my computer resources are limited, training with full data is too time-consuming)😂
#Sample partial data for quick validation of subsequent hyperparameter tuning, benchmark testing, and other analytical work
#Because my computer does not have GPU acceleration, I am only setting the tuning parameter range to 0.001-0.01 (approximately 3000 to 30000 data points for tuning)
#In the Kaggle environment, with GPU acceleration, full data can be used for tuning
po_subsample = po("subsample",frac= to_tune(0.001,0.01,logscale = T))
lrn_catboost_tune = lrn("classif.catboost")
lrn_catboost_tune = as_learner(po_subsample %>>% poca %>>% po("fixfactors") %>>% po_collapsefactors %>>% lrn_catboost_tune)
lrn_catboost_tune$id = "lrn_catboost_tune"
lrn_catboost_tune$predict_sets = c("train","test")
# 查看一下lrn_catboost_tune 流程图
lrn_catboost_tune$graph$plot(horizontal = T,html = T)
future::plan("multisession", workers = 4)
instance = tune(
  tuner = tnr("grid_search", resolution=4,batch_size = 4),
  task = task,
  learner = lrn_catboost_tune,
  #resampling = rsmp("cv", folds = 4 ),
  resampling = rsmp("holdout", ratio = 0.8 ),
  measures = msrs("classif.mcc")
)
## INFO  [18:41:41.276] [bbotk] Starting to optimize 1 parameter(s) with '<OptimizerBatchGridSearch>' and '<TerminatorNone>'
## INFO  [18:41:41.324] [bbotk] Evaluating 4 configuration(s)
## INFO  [18:41:41.355] [mlr3] Running benchmark with 4 resampling iterations
## INFO  [18:41:47.568] [mlr3] Applying learner 'lrn_catboost_tune' on task 'mushrooms' (iter 1/1)
## INFO  [18:41:54.526] [mlr3] Applying learner 'lrn_catboost_tune' on task 'mushrooms' (iter 1/1)
## INFO  [18:42:02.138] [mlr3] Applying learner 'lrn_catboost_tune' on task 'mushrooms' (iter 1/1)
## INFO  [18:42:12.233] [mlr3] Applying learner 'lrn_catboost_tune' on task 'mushrooms' (iter 1/1)
## INFO  [18:45:58.872] [mlr3] Finished benchmark
## INFO  [18:45:59.336] [bbotk] Result of batch 1:
## INFO  [18:45:59.342] [bbotk]  subsample.frac classif.mcc warnings errors runtime_learners
## INFO  [18:45:59.342] [bbotk]       -6.907755   0.9625063        0      0            63.87
## INFO  [18:45:59.342] [bbotk]       -6.140227   0.9755321        0      0            79.06
## INFO  [18:45:59.342] [bbotk]       -5.372699   0.9781019        0      0           121.09
## INFO  [18:45:59.342] [bbotk]       -4.605170   0.9799291        0      0           184.96
## INFO  [18:45:59.342] [bbotk]                                 uhash
## INFO  [18:45:59.342] [bbotk]  b5e9af13-7907-4cc6-a9cf-2d74f4b7dd3b
## INFO  [18:45:59.342] [bbotk]  dcbb2acc-5345-48ae-8f87-d0003afa0b41
## INFO  [18:45:59.342] [bbotk]  2119efc0-daa0-4b18-b1a8-36288264450a
## INFO  [18:45:59.342] [bbotk]  7b809b01-0446-4cc4-879d-d851a64a0d84
## INFO  [18:45:59.368] [bbotk] Finished optimizing after 4 evaluation(s)
## INFO  [18:45:59.369] [bbotk] Result:
## INFO  [18:45:59.372] [bbotk]  subsample.frac learner_param_vals  x_domain classif.mcc
## INFO  [18:45:59.372] [bbotk]           <num>             <list>    <list>       <num>
## INFO  [18:45:59.372] [bbotk]        -4.60517         <list[14]> <list[1]>   0.9799291
instance$result

the scores of the training and testing sets

instance$archive$benchmark_result$aggregate(msr("classif.mcc",predict_sets = 'train'))
instance$archive$benchmark_result$aggregate(msr("classif.mcc",predict_sets = 'test'))

7.2 Hyperparameter

Very time-consuming, use with caution

  • On computers without GPU acceleration, it is recommended to train with only 1% of the data.

  • If using GPU acceleration in a Kaggle environment, you can skip this code and use full data for tuning

task = po("subsample",frac = 0.01)$train(list(task))[[1]]
lrn_lightgbm_tune = lrn_lightgbm
ps_params = ps(
    num_iterations = p_int(500,1500),
    classif.lightgbm.learning_rate =  p_dbl(0.001,0.1,logscale = TRUE)
)
instance = tune(
  tuner = tnr("grid_search", resolution=10,batch_size = 4),
  task = task_sub,
  learner = lrn_lightgbm_tune,
  resampling = rsmp("cv", folds = 4 ),
  measures = msrs("classif.mcc"),
  search_space = ps_params
  #term_time = 3600*3.5
)
instance$result
instance_test =
  as.data.table(instance$archive)[,.(classif.mcc,
                                    learning_rate = x_domain_classif.lightgbm.learning_rate,
                                    batch_nr)]
instance_test[order(classif.mcc)]
instance_train =
  instance$archive$benchmark_result$aggregate(msr("classif.mcc",predict_sets ='train'))
instance_train %>% 
  ggplot(aes(x= log(learning_rate), color=batch_nr)) +
  geom_line(aes(y = classif.mcc),size = 2)+
  geom_point(aes(y = classif.mcc),size = 4)+
 
 
  geom_line(aes(y = train.mcc),size = 2)+
  geom_point(aes(y = train.mcc),size = 4)+
 
  scale_x_continuous(n.breaks = 10)

8 Train

print("learner train started ")
## [1] "learner train started "

use stacking model (cv 4 folds) will over kaggle max memory, so use this method to train . 😂

folds = 10
rr = rsmp("cv",folds =folds )
rr$instantiate(task)
index = 1

print(sprintf("round  %s start",index))
## [1] "round  1 start"
train_ids = rr$train_set(index)
val_ids = rr$test_set(index)

    
print("lightgbm train start")
## [1] "lightgbm train start"
lrn_lightgbm$train(task,row_ids = train_ids)

print("lrn_xgboost train start")
## [1] "lrn_xgboost train start"
lrn_xgboost$train(task,row_ids = train_ids)

print("lrn_catboost train start")
## [1] "lrn_catboost train start"
lrn_catboost$train(task,row_ids = train_ids)
## 0:   learn: 0.6291315    total: 161ms    remaining: 2m 41s
## 100: learn: 0.0474829    total: 1.65s    remaining: 14.6s
## 200: learn: 0.0296160    total: 3.12s    remaining: 12.4s
## 300: learn: 0.0216506    total: 4.59s    remaining: 10.7s
## 400: learn: 0.0159634    total: 6.15s    remaining: 9.19s
## 500: learn: 0.0122274    total: 7.61s    remaining: 7.58s
## 600: learn: 0.0109665    total: 9s   remaining: 5.97s
## 700: learn: 0.0087057    total: 10.4s    remaining: 4.45s
## 800: learn: 0.0070415    total: 11.9s    remaining: 2.95s
## 900: learn: 0.0061749    total: 13.3s    remaining: 1.46s
## 999: learn: 0.0057371    total: 14.6s    remaining: 0us
print("lightgbm predict start")
## [1] "lightgbm predict start"
pred_lightgbm = lrn_lightgbm$predict(task,row_ids = val_ids)
p_lightgbm = as.data.table(pred_lightgbm) 

print("xgboost predict start")
## [1] "xgboost predict start"
pred_xgboost = lrn_xgboost$predict(task,row_ids = val_ids)
p_xgboost = as.data.table(pred_xgboost)

print("catboost predict start")
## [1] "catboost predict start"
pred_catboost = lrn_catboost$predict(task,row_ids = val_ids)
p_catboost = as.data.table(pred_catboost)

print("learner use time")
## [1] "learner use time"
lrn_lightgbm$timings
##   train predict 
##    5.82    0.83
lrn_xgboost$timings
##   train predict 
##   58.94    2.11
lrn_catboost$timings
##   train predict 
##   16.58    2.72
print("combine stacking data")
## [1] "combine stacking data"
stacking_data = tibble(
    class = p_lightgbm$truth,
    lightgbm_prob = p_lightgbm$prob.p,
    xgboost_prob = p_xgboost$prob.p,
    catboost_prob = p_catboost$prob.p
)

task_stack = as_task_classif(stacking_data,target = 'class',positive = 'p')

lrn_glmnet = lrn("classif.glmnet")

lrn_glmnet$train(task_stack)

8.1 lightgbm

pred_lightgbm$confusion
##         truth
## response    p    e
##        p 1685   21
##        e   23 1389
pred_lightgbm$score(msr("classif.mcc"))
## classif.mcc 
##   0.9715209

8.2 catboost

pred_catboost$confusion
##         truth
## response    p    e
##        p 1685   17
##        e   23 1393
pred_catboost$score(msr("classif.mcc"))
## classif.mcc 
##    0.974123

8.3 xgboost

pred_xgboost$confusion
##         truth
## response    p    e
##        p 1689   16
##        e   19 1394
pred_xgboost$score(msr("classif.mcc"))
## classif.mcc 
##   0.9773488

9 Predict

test_ids =  test_data$id
test_data = test_data %>% select(-id)

lightgbm_preds <- lrn_lightgbm$predict_newdata(test_data)
xgboost_preds <- lrn_xgboost$predict_newdata(test_data)
catboost_preds <- lrn_catboost$predict_newdata(test_data)
print("learner predict newdata finished ")

test_stacking_data = tibble(
    lightgbm_prob = as.data.table(lightgbm_preds)$prob.p,
    xgboost_prob = as.data.table(xgboost_preds)$prob.p,
    catboost_prob = as.data.table(catboost_preds)$prob.p
)

preds = lrn_glmnet$predict_newdata(test_stacking_data)

result = tibble(
  id = test_ids,
  class = as.data.table(preds)$response
)

10 Submission

result  %>% write_csv("submission.csv")
print(" save  submission.csv  success ")

11 Model Interpretation

11.1 feature importance

feature_importances =
  as.data.table(lrn_lightgbm$base_learner()$importance(),keep.rownames = T) %>% 
    mutate(learner=  'lightgbm',rank = row_number())  %>% 
  bind_rows(
    as.data.table(lrn_catboost$base_learner()$importance(),keep.rownames = T) %>%
      mutate(learner=  'catboost',rank = row_number(),V2 = V2 /100)) %>% 
  bind_rows(
    as.data.table(lrn_xgboost$base_learner()$importance(),keep.rownames = T) %>% 
      mutate(learner=  'xgboost',rank = row_number())) %>% 
  rename(feature = 'V1',score= 'V2')
p1 = feature_importances %>% filter(rank<20 & learner == 'lightgbm'  ) %>% 
  ggplot(aes(x=reorder(feature,-rank),y=score))+
  geom_col() + 
  xlab("lightgbm")+
  coord_flip() 
p1

p2 = feature_importances %>% filter(rank<20 & learner == 'catboost'  ) %>% 
  ggplot(aes(x=reorder(feature,-rank),y=score))+
  geom_col() + 
  xlab("catboost")+
  coord_flip() 
p2

p3 = feature_importances %>% filter(rank<20 & learner == 'xgboost'  ) %>% 
  ggplot(aes(x=reorder(feature,-rank),y=score))+
  geom_col() + 
  xlab("xgboost")+
  coord_flip() 

p3

11.2 Local EMA

library(DALEX)
library(DALEXtra)
 
exp_data = task$data(rows = val_ids)
exp_data_x = exp_data [,-1]
exp_data_y = as.numeric(exp_data$class == "p")

lightgbm_exp = DALEXtra::explain_mlr3(lrn_lightgbm,
  data = exp_data_x,
  y = exp_data_y,
  label = "LightGBM predict",
  colorize = FALSE)
## Preparation of a new explainer is initiated
##   -> model label       :  LightGBM predict 
##   -> data              :  3118  rows  22  cols 
##   -> target variable   :  3118  values 
##   -> predict function  :  yhat.GraphLearner  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package mlr3 , ver. 0.20.2 , task classification (  default  ) 
##   -> predicted values  :  numerical, min =  0.0005375626 , mean =  0.4529094 , max =  0.9996255  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -0.9996255 , mean =  0.09487763 , max =  0.9994624  
##   A new explainer has been created!
catboost_exp = DALEXtra::explain_mlr3(lrn_catboost,
  data = exp_data_x,
  y = exp_data_y,
  label = "Catboost predict",
  colorize = FALSE)
## Preparation of a new explainer is initiated
##   -> model label       :  Catboost predict 
##   -> data              :  3118  rows  22  cols 
##   -> target variable   :  3118  values 
##   -> predict function  :  yhat.GraphLearner  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package mlr3 , ver. 0.20.2 , task classification (  default  ) 
##   -> predicted values  :  numerical, min =  3.795383e-06 , mean =  0.4533994 , max =  0.9999897  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -0.9999897 , mean =  0.09438762 , max =  0.9999962  
##   A new explainer has been created!
xgboost_exp = DALEXtra::explain_mlr3(lrn_xgboost,
  data = exp_data_x,
  y = exp_data_y,
  label = "Xgboost predict",
  colorize = FALSE)
## Preparation of a new explainer is initiated
##   -> model label       :  Xgboost predict 
##   -> data              :  3118  rows  22  cols 
##   -> target variable   :  3118  values 
##   -> predict function  :  yhat.GraphLearner  will be used (  default  )
##   -> predicted values  :  No value for predict function target column. (  default  )
##   -> model_info        :  package mlr3 , ver. 0.20.2 , task classification (  default  ) 
##   -> predicted values  :  numerical, min =  8.106232e-06 , mean =  0.4531135 , max =  0.9999851  
##   -> residual function :  difference between y and yhat (  default  )
##   -> residuals         :  numerical, min =  -0.9999851 , mean =  0.09467357 , max =  0.9999919  
##   A new explainer has been created!
lightgbm_exp
## Model label:  LightGBM predict 
## Model class:  GraphLearner,Learner,R6 
## Data head  :
##    cap_color cap_diameter cap_shape cap_surface does_bruise_or_bleed
##       <char>        <num>    <char>      <char>               <lgcl>
## 1:         n         3.97         f           t                FALSE
## 2:         y         1.63         x           g                FALSE
##    gill_attachment gill_color gill_spacing habitat has_ring is_outlier_char
##             <char>     <char>       <char>  <char>   <lgcl>           <num>
## 1:               x          p            c       g    FALSE               0
## 2:               d          n            d       h    FALSE               0
##    is_outlier_num ring_type season spore_print_color stem_color stem_height
##             <num>    <char> <char>            <char>     <char>       <num>
## 1:              0         f      a              UNKN          n        5.44
## 2:              0         f      a              UNKN          n        2.31
##    stem_root stem_surface stem_width veil_color veil_type
##       <char>       <char>      <num>     <char>    <char>
## 1:      UNKN            i       4.30       UNKN      UNKN
## 2:      UNKN            s       2.77       UNKN      UNKN
catboost_exp
## Model label:  Catboost predict 
## Model class:  GraphLearner,Learner,R6 
## Data head  :
##    cap_color cap_diameter cap_shape cap_surface does_bruise_or_bleed
##       <char>        <num>    <char>      <char>               <lgcl>
## 1:         n         3.97         f           t                FALSE
## 2:         y         1.63         x           g                FALSE
##    gill_attachment gill_color gill_spacing habitat has_ring is_outlier_char
##             <char>     <char>       <char>  <char>   <lgcl>           <num>
## 1:               x          p            c       g    FALSE               0
## 2:               d          n            d       h    FALSE               0
##    is_outlier_num ring_type season spore_print_color stem_color stem_height
##             <num>    <char> <char>            <char>     <char>       <num>
## 1:              0         f      a              UNKN          n        5.44
## 2:              0         f      a              UNKN          n        2.31
##    stem_root stem_surface stem_width veil_color veil_type
##       <char>       <char>      <num>     <char>    <char>
## 1:      UNKN            i       4.30       UNKN      UNKN
## 2:      UNKN            s       2.77       UNKN      UNKN
xgboost_exp
## Model label:  Xgboost predict 
## Model class:  GraphLearner,Learner,R6 
## Data head  :
##    cap_color cap_diameter cap_shape cap_surface does_bruise_or_bleed
##       <char>        <num>    <char>      <char>               <lgcl>
## 1:         n         3.97         f           t                FALSE
## 2:         y         1.63         x           g                FALSE
##    gill_attachment gill_color gill_spacing habitat has_ring is_outlier_char
##             <char>     <char>       <char>  <char>   <lgcl>           <num>
## 1:               x          p            c       g    FALSE               0
## 2:               d          n            d       h    FALSE               0
##    is_outlier_num ring_type season spore_print_color stem_color stem_height
##             <num>    <char> <char>            <char>     <char>       <num>
## 1:              0         f      a              UNKN          n        5.44
## 2:              0         f      a              UNKN          n        2.31
##    stem_root stem_surface stem_width veil_color veil_type
##       <char>       <char>      <num>     <char>    <char>
## 1:      UNKN            i       4.30       UNKN      UNKN
## 2:      UNKN            s       2.77       UNKN      UNKN
perf_credit = model_performance(lightgbm_exp)
perf_credit
## Measures for:  classification
## recall     : 0.01346604 
## precision  : 0.01628895 
## f1         : 0.01474359 
## accuracy   : 0.01411161 
## auc        : 0.004515256
## 
## Residuals:
##         0%        10%        20%        30%        40%        50%        60% 
## -0.9996255 -0.9965718 -0.9935112 -0.9887757 -0.9707435  0.9708476  0.9911378 
##        70%        80%        90%       100% 
##  0.9949240  0.9968255  0.9980758  0.9994624
obser = exp_data_x[14, ]
predict(lightgbm_exp, obser)
##          e 
## 0.07786252
lightgbm_predict_parts =  predict_parts(lightgbm_exp, new_observation = obser)
lightgbm_predict_parts
plot(lightgbm_predict_parts)

obser = exp_data_x[14, ]
predict(catboost_exp, obser)
##          e 
## 0.02165201
catboost_predict_parts =  predict_parts(catboost_exp, new_observation = obser)
catboost_predict_parts
plot(catboost_predict_parts)

obser = exp_data_x[14, ]
predict(xgboost_exp, obser)
##          e 
## 0.06630564
xgboost_predict_parts =  predict_parts(xgboost_exp, new_observation = obser)
xgboost_predict_parts
plot(xgboost_predict_parts)