Data

Отметим студентов, которые нуждаются в помощи как тех, кто получиль 10 и меньше на финальном экзамене в 4 модуле

student_mat = student_mat %>% 
  mutate(help = case_when(
    G3<=10 ~ "need_support",
    T~ "no_support"
  ) %>% factor()
  ) %>% 
  select(-G3)

Посмотрим на 2 новые переменные

Давайте будем считать что это оценки за 1 и 2 второй модули

student_mat %>% 
  ggplot(aes(x = G1, y = G2, color = help))+
  geom_jitter()+
  theme_minimal()

тут мы видим чёткую корреляция между оценками, чем больше оценка за первый модель, тем больше она за 2

стравним наши группе тем методом, что использовали на прошлой паре

library(compareGroups)
compareGroups::compareGroups(help ~ G1 + G2, data = student_mat) %>% 
  createTable
## 
## --------Summary descriptives table by 'help'---------
## 
## _____________________________________ 
##    need_support no_support  p.overall 
##       N=186        N=209              
## ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯ 
## G1 8.26 (1.80)  13.3 (2.47)  <0.001   
## G2 7.70 (2.67)  13.4 (2.26)  <0.001   
## ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯

compareGroups рабатает, и даже показывет нам что среднее в двух группах отличается, но нас всё же интересует граница

Gini index

Gini index это популярная метрика в социальных науках, чтобы исследовать неравнество

https://en.wikipedia.org/wiki/Gini_coefficient

она пришла от кривой Лоренца

и в супер равном обществе люди получают одинаково, т.е

Graphical representation of the Gini coefficien

Представим, что мы все люди зарабатывают одинаково. И в нашем обществе из 10 людей все получают 1.

Так если мы отсортируем всех людей по убыванию зарплаты и к каждой предыдущей зарплате будем прибавлять следующую, в списке, то мы получим прямую

df = data_frame(salary = c(rep(1, 10)))
## Warning: `data_frame()` is deprecated as of tibble 1.1.0.
## Please use `tibble()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
df = df %>% 
  arrange(-salary) %>% 
  mutate(position = 1:10) %>% 
  mutate(prob = (salary / sum (salary) * 100) %>% round(0)) %>% ### TODO
  mutate(prob_csum = cumsum(prob))

df %>% 
ggplot(aes(x = position, y = prob_csum))+
  geom_bar(stat = "identity")+
  ylab("%")

Теперь посмотрим на очень неравное общество, где большинство получает 1, а двое получают 2, и 1 получает 5

Теперь у нас получилась совсем не равное общество

df2 = data_frame(salary = c(rep(1, 7), 2, 2, 5))
df2 = df2 %>% 
  arrange(-salary) %>% 
  mutate(position = 1:10) %>% 
  mutate(prob = (salary / sum (salary) * 100) %>% round(0)) %>% ### TODO
  mutate(prob_m = mean(prob),
         prob_m_csum = cumsum(prob_m),
    prob_csum = cumsum(prob))

df2 %>% 
ggplot()+
  geom_bar(aes(x = position, y = prob_csum), stat = "identity")+
  geom_bar(aes(x = position, y = prob_m_csum), stat = "identity", alpha = 0.5, fill = "cyan4")+
  ylab("% of salary")

gini index считает разницу между реальным распределением и теоретическием равным. Так чем больше gini, тем больше неравенство

gini index map

formula

один из способов посчитать gini index в R:

  # options(dplyr.summarise.inform = FALSE)

gini_process_dplyr <- function(data, classes, splitvar) {
  
   base_prob = data %>% 
      count({{ classes }})  %>% 
      group_by({{ classes }}) %>% 
      summarise(share = (n/nrow(data)))
   
    data = data%>%
      count({{ classes }}, {{ splitvar }})%>%
      group_by({{ splitvar }}) %>%
      mutate(prop_sqr = prop.table(n) **2) %>%
      select(-n) %>%
      pivot_wider(names_from = {{ splitvar }},
                  values_from = prop_sqr, names_prefix = "splitvar_") %>%
              mutate(
      across(everything(), ~replace_na(.x, 0))
    ) %>% 
      mutate(gini = 1- colSums(select(., starts_with("splitvar_")))) %>%
      left_join(base_prob) %>%
        mutate(gini_index = (gini * share))
        # pull() %>%
        # replace_na(0) %>%
    
    gini_split = data$gini[1]*data$share[1] +  data$gini[2]* data$share[2]
    print(str_glue("Gini1: {data$gini[1]}, share1: {data$share[1]} \n
                   Gini2: {data$gini[2]}, share1: {data$share[2]} \n
                   Gini for split: {gini_split}"))
    
    return(gini_split)
}
?left_join

Т.е. если при одном разбиении 4 студентов, 2 нуждаются в помощи, а 2 не нуждаются, Gini будет 0.25. Это примерно равная группа, а мы ищем не равную.

А если при другом разбиении получится, что 3 будут нуждаться в помощи, а 1 не нуждаться, то Gini где-то 0.1875

data_frame(
  group = c("need_help", "need_help", "need_help", "need_help" ,
            "no_help", "no_help", "no_help", "no_help"),
  split_impure = c(1, 2, 1, 2, 1, 2, 1, 2),
  split_2 = c(1 ,1, 1, 2, 2, 2 ,2, 1),
  split_perfect = c(1, 1, 1, 1, 2 ,2 ,2 ,2)
           ) %>% 
  gini_process_dplyr(group, split_impure)
## Gini1: 0.5, share1: 0.5 
## 
## Gini2: 0.5, share1: 0.5 
## 
## Gini for split: 0.5
## splitvar_1 
##        0.5

Так в результате мы увидем 2 gini для каждого значения переменной

student_mat %>% 
  mutate(g1_10 =case_when(
    G1 >= 10 ~ ">=10",
    T ~ "<10"
  )) %>% 
gini_process_dplyr(help, g1_10)
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning in `[<-.factor`(`*tmp*`, !is_complete(data), value = 0): invalid factor
## level, NA generated
## Gini1: 0.130926403491371, share1: 0.470886075949367 
## 
## Gini2: 0.335765283007077, share1: 0.529113924050633 
## 
## Gini for split: 0.239309506830061
## splitvar_<10 
##    0.2393095

Почитаем gini для каждого варианта разбиения по оценкам: оценки меньше 1 и больше, потом меньше 2 и больше, и так до конца. Для каждого разбиения мы увидим gini, который будет показывать насколько равное распределение тех кому нужена помощь.

student_mat %>% 
  mutate(g1_10 =case_when(
    G1 >= 10 ~ "more",
    T ~ "less"
  )) %>% 
gini_process_dplyr(help, g1_10) 
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning in `[<-.factor`(`*tmp*`, !is_complete(data), value = 0): invalid factor
## level, NA generated
## Gini1: 0.130926403491371, share1: 0.470886075949367 
## 
## Gini2: 0.335765283007077, share1: 0.529113924050633 
## 
## Gini for split: 0.239309506830061
## splitvar_less 
##     0.2393095
gini = NA

for (i in 1:20) {
  gini = student_mat %>% 
  mutate(g1_10 =case_when(
    G1 >= i ~ "more",
    T ~ "less"
  )) %>% 
gini_process_dplyr(help, g1_10) %>% 
    rbind(gini)
}
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.498304758852748, share1: 0.470886075949367 
## 
## Gini2: 0.498304758852748, share1: 0.529113924050633 
## 
## Gini for split: 0.498304758852748
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.498304758852748, share1: 0.470886075949367 
## 
## Gini2: 0.498304758852748, share1: 0.529113924050633 
## 
## Gini for split: 0.498304758852748
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.498304758852748, share1: 0.470886075949367 
## 
## Gini2: 0.498304758852748, share1: 0.529113924050633 
## 
## Gini for split: 0.498304758852748
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0, share1: 0.470886075949367 
## 
## Gini2: 0.498144760236028, share1: 0.529113924050633 
## 
## Gini for split: 0.263575328833746
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0, share1: 0.470886075949367 
## 
## Gini2: 0.497976678385745, share1: 0.529113924050633 
## 
## Gini for split: 0.263486394386382
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0, share1: 0.470886075949367 
## 
## Gini2: 0.496563666138688, share1: 0.529113924050633 
## 
## Gini for split: 0.26273874993161
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0, share1: 0.470886075949367 
## 
## Gini2: 0.488034553279814, share1: 0.529113924050633 
## 
## Gini for split: 0.25822587755818
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.0555102040816327, share1: 0.470886075949367 
## 
## Gini2: 0.462504142011834, share1: 0.529113924050633 
## 
## Gini for split: 0.270856363644701
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.118172226280334, share1: 0.470886075949367 
## 
## Gini2: 0.41073199761952, share1: 0.529113924050633 
## 
## Gini for split: 0.272969674912967
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.130926403491371, share1: 0.470886075949367 
## 
## Gini2: 0.335765283007077, share1: 0.529113924050633 
## 
## Gini for split: 0.239309506830061
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.248060350613439, share1: 0.470886075949367 
## 
## Gini2: 0.186305264189785, share1: 0.529113924050633 
## 
## Gini for split: 0.215384874505734
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.333197086801427, share1: 0.470886075949367 
## 
## Gini2: 0.0361323346757498, share1: 0.529113924050633 
## 
## Gini for split: 0.176015990107081
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.42559160599812, share1: 0.470886075949367 
## 
## Gini2: 0.0155029296875, share1: 0.529113924050633 
## 
## Gini for split: 0.208607977266678
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.4712, share1: 0.470886075949367 
## 
## Gini2: 0, share1: 0.529113924050633 
## 
## Gini for split: 0.221881518987342
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.491900826446281, share1: 0.470886075949367 
## 
## Gini2: 0, share1: 0.529113924050633 
## 
## Gini for split: 0.23162924992154
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.498707268026429, share1: 0.470886075949367 
## 
## Gini2: 0, share1: 0.529113924050633 
## 
## Gini for split: 0.234834308488394
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.49994341330919, share1: 0.470886075949367 
## 
## Gini2: 0, share1: 0.529113924050633 
## 
## Gini for split: 0.235416392089897
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.49951171875, share1: 0.470886075949367 
## 
## Gini2: 0, share1: 0.529113924050633 
## 
## Gini for split: 0.235213113132911
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.498698458975427, share1: 0.470886075949367 
## 
## Gini2: 0, share1: 0.529113924050633 
## 
## Gini for split: 0.234830160428935
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.

## Warning: invalid factor level, NA generated
## Gini1: 0.498304758852748, share1: 0.470886075949367 
## 
## Gini2: 0.498304758852748, share1: 0.529113924050633 
## 
## Gini for split: 0.498304758852748

Теперь нарисуем как изменяется gini относительно оценки

data_frame(gini = gini) %>% 
  na.omit() %>% 
  mutate(i = 1:20) %>% 
  ggplot()+
  geom_line(aes(x = i, y = gini))+
  geom_jitter(data = student_mat, aes(x = G1, y = 0.2, color = help),alpha = 0.8, height = 0.20)+
  theme_minimal()+
  ylab("gini index")+
  ylim(c(0, 0.5))

generated data

generated = data_frame(
  G1 = 1:20,
  help = c(rep("need_support", 10) ,rep("no_support", 10))
)


gini_generated = sapply(1:20, function(i){
  generated %>% 
  mutate(g1_10 =case_when(
    G1 >= i ~ "more",
    T ~ "less"
  )) %>% 
gini_process_dplyr(help, g1_10)
})
## Gini1: 0.5, share1: 0.5 
## 
## Gini2: 0.5, share1: 0.5 
## 
## Gini for split: 0.5
## Gini1: 0, share1: 0.5 
## 
## Gini2: 0.498614958448754, share1: 0.5 
## 
## Gini for split: 0.249307479224377
## Gini1: 0, share1: 0.5 
## 
## Gini2: 0.493827160493827, share1: 0.5 
## 
## Gini for split: 0.246913580246914
## Gini1: 0, share1: 0.5 
## 
## Gini2: 0.484429065743945, share1: 0.5 
## 
## Gini for split: 0.242214532871972
## Gini1: 0, share1: 0.5 
## 
## Gini2: 0.46875, share1: 0.5 
## 
## Gini for split: 0.234375
## Gini1: 0, share1: 0.5 
## 
## Gini2: 0.444444444444444, share1: 0.5 
## 
## Gini for split: 0.222222222222222
## Gini1: 0, share1: 0.5 
## 
## Gini2: 0.408163265306122, share1: 0.5 
## 
## Gini for split: 0.204081632653061
## Gini1: 0, share1: 0.5 
## 
## Gini2: 0.355029585798816, share1: 0.5 
## 
## Gini for split: 0.177514792899408
## Gini1: 0, share1: 0.5 
## 
## Gini2: 0.277777777777778, share1: 0.5 
## 
## Gini for split: 0.138888888888889
## Gini1: 0, share1: 0.5 
## 
## Gini2: 0.165289256198347, share1: 0.5 
## 
## Gini for split: 0.0826446280991736
## Gini1: 0, share1: 0.5 
## 
## Gini2: 0, share1: 0.5 
## 
## Gini for split: 0
## Gini1: 0.165289256198347, share1: 0.5 
## 
## Gini2: 0, share1: 0.5 
## 
## Gini for split: 0.0826446280991736
## Gini1: 0.277777777777778, share1: 0.5 
## 
## Gini2: 0, share1: 0.5 
## 
## Gini for split: 0.138888888888889
## Gini1: 0.355029585798816, share1: 0.5 
## 
## Gini2: 0, share1: 0.5 
## 
## Gini for split: 0.177514792899408
## Gini1: 0.408163265306122, share1: 0.5 
## 
## Gini2: 0, share1: 0.5 
## 
## Gini for split: 0.204081632653061
## Gini1: 0.444444444444444, share1: 0.5 
## 
## Gini2: 0, share1: 0.5 
## 
## Gini for split: 0.222222222222222
## Gini1: 0.46875, share1: 0.5 
## 
## Gini2: 0, share1: 0.5 
## 
## Gini for split: 0.234375
## Gini1: 0.484429065743945, share1: 0.5 
## 
## Gini2: 0, share1: 0.5 
## 
## Gini for split: 0.242214532871972
## Gini1: 0.493827160493827, share1: 0.5 
## 
## Gini2: 0, share1: 0.5 
## 
## Gini for split: 0.246913580246914
## Gini1: 0.498614958448754, share1: 0.5 
## 
## Gini2: 0, share1: 0.5 
## 
## Gini for split: 0.249307479224377
data_frame(gini = gini_generated) %>% 
  mutate(i = 1:20) %>% 
  ggplot()+
  geom_line(aes(x = i, y = gini))+
  geom_jitter(data = generated, aes(x = G1, y = 0.2, color = help),alpha = 0.8, height = 0.20)+
  theme_minimal()+
  ylab("gini index")+
  ylim(c(0, 0.5))

То же самое сделаем с другой бинарной переменной. Какой нужно выбрать?

student_mat %>% 
  gini_process_dplyr(help, higher)
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning in `[<-.factor`(`*tmp*`, !is_complete(data), value = 0): invalid factor
## level, NA generated
## Gini1: 0.32, share1: 0.470886075949367 
## 
## Gini2: 0.495644444444445, share1: 0.529113924050633 
## 
## Gini for split: 0.412935921237693
## splitvar_no 
##   0.4129359

Нарисуем наши разбиения

student_mat %>% 
  ggplot(aes(x = G1, y = G2, color = help))+
  geom_jitter(alpha = 0.8)+
  theme_minimal()+
    annotate("rect", alpha = .1, fill = "blue", color = 1,
           xmin = 0, xmax = 9, 
           ymin = -Inf, ymax = Inf)+
      annotate("rect", alpha = .1, fill = "yellow", color = 1,
           xmin = -Inf, xmax = Inf, 
           ymin = 0, ymax = 10)

student_mat %>% 
  ggplot(aes(x = higher, fill = help))+
  geom_bar(position = "fill")

student_mat %>% 
  ggplot(aes(x = higher, fill = help))+
  geom_bar()

Другая метрика это information gain, которая связана с изменением энтропии.

Эта метрика показывает вероятноять выбрать переменную которая принадлежит к интересному для нас классу, при условии, что мы будем выбирать случайно

Так если после некоторого разделения у нас осталось 4 студента и все из них в группе, которой нужна помощь, то мы получаем 0 энтропию, так как точно можем сказать студента из какой группы мы выберем, если будем выбирать случайно.

Information Gain отражает разницу между суммой энтропии класса и атрибута минус вероятность класса при условии атрибута

Так если у атриутов по отдельности высокая энтропия: из 4 студентов 2 нуждаются а 2 не нуждаются в помощи, а при сочетании с переменной высокая оценка

Information Gain

library(FSelectorRcpp)

?information_gain


information_gain(               # Calculate the score for each attribute
    formula = help ~ .,      # that is on the right side of the formula.
    data = student_mat,                # Attributes must exist in the passed data.
    type  = "infogain"          # Choose the type of a score to be calculated.
  ) %>% 
  arrange(-importance)
##    attributes   importance
## 1          G2 5.190839e-01
## 2          G1 4.061086e-01
## 3    failures 4.659940e-02
## 4        Fedu 2.150424e-02
## 5        Mjob 2.043180e-02
## 6        Medu 1.855657e-02
## 7      higher 1.220407e-02
## 8   schoolsup 9.296121e-03
## 9     address 5.443534e-03
## 10       Fjob 4.593047e-03
## 11   internet 4.426085e-03
## 12     school 3.564076e-03
## 13   guardian 2.684456e-03
## 14        sex 2.571527e-03
## 15     reason 2.480221e-03
## 16    Pstatus 1.526086e-03
## 17       paid 9.277553e-04
## 18    famsize 8.505739e-04
## 19     famsup 5.031213e-04
## 20    nursery 3.626333e-04
## 21   romantic 4.107437e-05
## 22 activities 2.161684e-05
## 23        age 0.000000e+00
## 24 traveltime 0.000000e+00
## 25  studytime 0.000000e+00
## 26     famrel 0.000000e+00
## 27   freetime 0.000000e+00
## 28      goout 0.000000e+00
## 29       Dalc 0.000000e+00
## 30       Walc 0.000000e+00
## 31     health 0.000000e+00
## 32   absences 0.000000e+00
# %>%                          
  # cut_attrs(                    # Then take attributes with the highest rank.
  #   k = 2                       # For example: 2 attrs with the higehst rank.
  # ) %>%                         
  # to_formula(                   # Create a new formula object with 
  #   attrs = .,                  # the most influencial attrs.
  #   class = "Species"           
  # )

Test train

ещё одна важная часть

?initial_split

student_mat_split <- initial_split(student_mat, prop = 3/4)
train_data <- training(student_mat_split)
test_data <- testing(student_mat_split)
train_data
## # A tibble: 297 x 33
##    school sex     age address famsize Pstatus Medu  Fedu  Mjob  Fjob  reason
##    <fct>  <fct> <dbl> <fct>   <fct>   <fct>   <fct> <fct> <fct> <fct> <fct> 
##  1 GP     F        15 U       GT3     T       high… prim… heal… serv… home  
##  2 GP     F        16 U       GT3     T       seco… seco… other other home  
##  3 GP     M        15 U       GT3     T       seco… high… other other home  
##  4 GP     F        15 U       GT3     T       prim… prim… serv… other reput…
##  5 GP     M        15 U       LE3     T       high… high… heal… serv… course
##  6 GP     M        15 U       GT3     T       high… seco… teac… other course
##  7 GP     M        15 U       GT3     A       prim… prim… other other home  
##  8 GP     F        16 U       GT3     T       high… high… serv… serv… reput…
##  9 GP     F        16 U       GT3     T       seco… seco… other other reput…
## 10 GP     M        16 U       LE3     T       high… seco… heal… other home  
## # … with 287 more rows, and 22 more variables: guardian <fct>,
## #   traveltime <dbl>, studytime <dbl>, failures <dbl>, schoolsup <fct>,
## #   famsup <fct>, paid <fct>, activities <fct>, nursery <fct>, higher <fct>,
## #   internet <fct>, romantic <fct>, famrel <dbl>, freetime <dbl>, goout <dbl>,
## #   Dalc <dbl>, Walc <dbl>, health <dbl>, absences <dbl>, G1 <dbl>, G2 <dbl>,
## #   help <fct>

ggparty

library(partykit)
library(ggparty)
?partysplit
ps1 = partykit::partysplit(which(names(student_mat) == "G1"), breaks = 9)
ps2 = partykit::partysplit(which(names(student_mat) == "G2"), breaks = 10)

# ps11 = partysplit(which(names(adult_data) == "education-num"), breaks = 12)

ps111 = partykit::partysplit(which(names(student_mat) == "higher"), index = c(1L, 2L))
partykit::character_split(ps111, student_mat)
## $name
## [1] "higher"
## 
## $levels
## [1] "no"  "yes"
which(names(student_mat) == "higher")
## [1] 21
# pn1 = partynode(1, split = ps1, 
#                 kids = list(
#                   partynode(id = 1, info = ">50K"),
#                   partynode(id = 2, info = "<=50K")
#                 ))

pn1 = partynode(1L, split = ps1, 
                kids = list(
                  partynode(2L, split = ps2, kids = list(
                    partynode(4L, split = ps111, kids = list(
                      partynode(6L, info = "need_support"),
                      partynode(7L, info = "no_support")
                    )),
                    partynode(5L, info = "no_support")
                  )),
                  partynode(3L, info = "no_support")
                ))
pn1
## [1] root
## |   [2] V31 <= 9
## |   |   [4] V32 <= 10
## |   |   |   [6] V21 <= 1 *
## |   |   |   [7] V21 > 1 *
## |   |   [5] V32 > 10 *
## |   [3] V31 > 9 *
py1 <- party(pn1, student_mat)
py1
## [1] root
## |   [2] G1 <= 9
## |   |   [3] G2 <= 10
## |   |   |   [4] higher in no: need_support
## |   |   |   [5] higher in yes: no_support
## |   |   [6] G2 > 10: no_support
## |   [7] G1 > 9: no_support
plot(py1)

ggparty(py1) +
  geom_edge() +
  # geom_edge_label() +
  geom_node_splitvar() +
  # pass list to gglist containing all ggplot components we want to plot for each
  # (default: terminal) node
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = help),
                                        position = position_fill()),
                               xlab("help"),
                               ylab("%"),
                               theme_minimal()))

gp1 =ggparty(py1, terminal_space = 0.2) +
  geom_edge() +
  # geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(ids = 1,
                 # size = 1.5,
                 # height = 0.5,
                 gglist = list(
                   geom_histogram(
                     aes(
                     x = G1, fill = help
                   )),
                   geom_vline(aes(xintercept = 9)),
                   theme_bw(),
                   theme(axis.title.y = element_blank(), legend.position='none')
                 ))+
    geom_node_plot(ids = 2, 
                 # size = 1.5,
                 # height = 0.5,
                 gglist = list(
                   geom_histogram(
                     aes(
                     x = G2, fill = help
                   )),
                   geom_vline(aes(xintercept = 10)),
                   theme_bw(),
                   theme(axis.title.y = element_blank(), legend.position='none')
                 ))+
      geom_node_plot(ids = 3, 
                     # size = 1.5,
                 # height = 0.5,
                 gglist = list(
                   geom_bar(
                     aes(
                     x = higher, fill = help
                   ), position = "fill"),
                   # scale_alpha_discrete(range = c(0.5, 1)),
                   theme_bw(),
                   theme(axis.title.y = element_blank(), axis.text.x = element_text(angle = 45, hjust = 1), legend.position='none')
                 ))+
  # pass list to gglist containing all ggplot components we want to plot for each
  # (default: terminal) node
  geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = help),
                                        position = position_fill()),
                               xlab("help"),
                               ylab("%"),
                               theme_minimal(),
                               theme(legend.position='none')))+
  theme(legend.position = "none")

gp1

ggsave(gp1,scale = 0.4, filename =  "ggparty_students.png")

Tasks

  • Зачем нужен test/train split?
  • Что измеряет gini index?