Отметим студентов, которые нуждаются в помощи как тех, кто получиль 10 и меньше на финальном экзамене в 4 модуле
student_mat = student_mat %>%
mutate(help = case_when(
G3<=10 ~ "need_support",
T~ "no_support"
) %>% factor()
) %>%
select(-G3)
Посмотрим на 2 новые переменные
Давайте будем считать что это оценки за 1 и 2 второй модули
student_mat %>%
ggplot(aes(x = G1, y = G2, color = help))+
geom_jitter()+
theme_minimal()
тут мы видим чёткую корреляция между оценками, чем больше оценка за первый модель, тем больше она за 2
стравним наши группе тем методом, что использовали на прошлой паре
library(compareGroups)
compareGroups::compareGroups(help ~ G1 + G2, data = student_mat) %>%
createTable
##
## --------Summary descriptives table by 'help'---------
##
## _____________________________________
## need_support no_support p.overall
## N=186 N=209
## ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
## G1 8.26 (1.80) 13.3 (2.47) <0.001
## G2 7.70 (2.67) 13.4 (2.26) <0.001
## ¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯
compareGroups рабатает, и даже показывет нам что среднее в двух группах отличается, но нас всё же интересует граница
Gini index это популярная метрика в социальных науках, чтобы исследовать неравнество
https://en.wikipedia.org/wiki/Gini_coefficient
она пришла от кривой Лоренца
и в супер равном обществе люди получают одинаково, т.е
Graphical representation of the Gini coefficien
Представим, что мы все люди зарабатывают одинаково. И в нашем обществе из 10 людей все получают 1.
Так если мы отсортируем всех людей по убыванию зарплаты и к каждой предыдущей зарплате будем прибавлять следующую, в списке, то мы получим прямую
df = data_frame(salary = c(rep(1, 10)))
## Warning: `data_frame()` is deprecated as of tibble 1.1.0.
## Please use `tibble()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
df = df %>%
arrange(-salary) %>%
mutate(position = 1:10) %>%
mutate(prob = (salary / sum (salary) * 100) %>% round(0)) %>% ### TODO
mutate(prob_csum = cumsum(prob))
df %>%
ggplot(aes(x = position, y = prob_csum))+
geom_bar(stat = "identity")+
ylab("%")
Теперь посмотрим на очень неравное общество, где большинство получает 1, а двое получают 2, и 1 получает 5
Теперь у нас получилась совсем не равное общество
df2 = data_frame(salary = c(rep(1, 7), 2, 2, 5))
df2 = df2 %>%
arrange(-salary) %>%
mutate(position = 1:10) %>%
mutate(prob = (salary / sum (salary) * 100) %>% round(0)) %>% ### TODO
mutate(prob_m = mean(prob),
prob_m_csum = cumsum(prob_m),
prob_csum = cumsum(prob))
df2 %>%
ggplot()+
geom_bar(aes(x = position, y = prob_csum), stat = "identity")+
geom_bar(aes(x = position, y = prob_m_csum), stat = "identity", alpha = 0.5, fill = "cyan4")+
ylab("% of salary")
gini index считает разницу между реальным распределением и теоретическием равным. Так чем больше gini, тем больше неравенство
gini index map
formula
один из способов посчитать gini index в R:
# options(dplyr.summarise.inform = FALSE)
gini_process_dplyr <- function(data, classes, splitvar) {
base_prob = data %>%
count({{ classes }}) %>%
group_by({{ classes }}) %>%
summarise(share = (n/nrow(data)))
data = data%>%
count({{ classes }}, {{ splitvar }})%>%
group_by({{ splitvar }}) %>%
mutate(prop_sqr = prop.table(n) **2) %>%
select(-n) %>%
pivot_wider(names_from = {{ splitvar }},
values_from = prop_sqr, names_prefix = "splitvar_") %>%
mutate(
across(everything(), ~replace_na(.x, 0))
) %>%
mutate(gini = 1- colSums(select(., starts_with("splitvar_")))) %>%
left_join(base_prob) %>%
mutate(gini_index = (gini * share))
# pull() %>%
# replace_na(0) %>%
gini_split = data$gini[1]*data$share[1] + data$gini[2]* data$share[2]
print(str_glue("Gini1: {data$gini[1]}, share1: {data$share[1]} \n
Gini2: {data$gini[2]}, share1: {data$share[2]} \n
Gini for split: {gini_split}"))
return(gini_split)
}
?left_join
Т.е. если при одном разбиении 4 студентов, 2 нуждаются в помощи, а 2 не нуждаются, Gini будет 0.25. Это примерно равная группа, а мы ищем не равную.
А если при другом разбиении получится, что 3 будут нуждаться в помощи, а 1 не нуждаться, то Gini где-то 0.1875
data_frame(
group = c("need_help", "need_help", "need_help", "need_help" ,
"no_help", "no_help", "no_help", "no_help"),
split_impure = c(1, 2, 1, 2, 1, 2, 1, 2),
split_2 = c(1 ,1, 1, 2, 2, 2 ,2, 1),
split_perfect = c(1, 1, 1, 1, 2 ,2 ,2 ,2)
) %>%
gini_process_dplyr(group, split_impure)
## Gini1: 0.5, share1: 0.5
##
## Gini2: 0.5, share1: 0.5
##
## Gini for split: 0.5
## splitvar_1
## 0.5
Так в результате мы увидем 2 gini для каждого значения переменной
student_mat %>%
mutate(g1_10 =case_when(
G1 >= 10 ~ ">=10",
T ~ "<10"
)) %>%
gini_process_dplyr(help, g1_10)
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning in `[<-.factor`(`*tmp*`, !is_complete(data), value = 0): invalid factor
## level, NA generated
## Gini1: 0.130926403491371, share1: 0.470886075949367
##
## Gini2: 0.335765283007077, share1: 0.529113924050633
##
## Gini for split: 0.239309506830061
## splitvar_<10
## 0.2393095
Почитаем gini для каждого варианта разбиения по оценкам: оценки меньше 1 и больше, потом меньше 2 и больше, и так до конца. Для каждого разбиения мы увидим gini, который будет показывать насколько равное распределение тех кому нужена помощь.
student_mat %>%
mutate(g1_10 =case_when(
G1 >= 10 ~ "more",
T ~ "less"
)) %>%
gini_process_dplyr(help, g1_10)
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning in `[<-.factor`(`*tmp*`, !is_complete(data), value = 0): invalid factor
## level, NA generated
## Gini1: 0.130926403491371, share1: 0.470886075949367
##
## Gini2: 0.335765283007077, share1: 0.529113924050633
##
## Gini for split: 0.239309506830061
## splitvar_less
## 0.2393095
gini = NA
for (i in 1:20) {
gini = student_mat %>%
mutate(g1_10 =case_when(
G1 >= i ~ "more",
T ~ "less"
)) %>%
gini_process_dplyr(help, g1_10) %>%
rbind(gini)
}
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.498304758852748, share1: 0.470886075949367
##
## Gini2: 0.498304758852748, share1: 0.529113924050633
##
## Gini for split: 0.498304758852748
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.498304758852748, share1: 0.470886075949367
##
## Gini2: 0.498304758852748, share1: 0.529113924050633
##
## Gini for split: 0.498304758852748
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.498304758852748, share1: 0.470886075949367
##
## Gini2: 0.498304758852748, share1: 0.529113924050633
##
## Gini for split: 0.498304758852748
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0, share1: 0.470886075949367
##
## Gini2: 0.498144760236028, share1: 0.529113924050633
##
## Gini for split: 0.263575328833746
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0, share1: 0.470886075949367
##
## Gini2: 0.497976678385745, share1: 0.529113924050633
##
## Gini for split: 0.263486394386382
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0, share1: 0.470886075949367
##
## Gini2: 0.496563666138688, share1: 0.529113924050633
##
## Gini for split: 0.26273874993161
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0, share1: 0.470886075949367
##
## Gini2: 0.488034553279814, share1: 0.529113924050633
##
## Gini for split: 0.25822587755818
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.0555102040816327, share1: 0.470886075949367
##
## Gini2: 0.462504142011834, share1: 0.529113924050633
##
## Gini for split: 0.270856363644701
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.118172226280334, share1: 0.470886075949367
##
## Gini2: 0.41073199761952, share1: 0.529113924050633
##
## Gini for split: 0.272969674912967
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.130926403491371, share1: 0.470886075949367
##
## Gini2: 0.335765283007077, share1: 0.529113924050633
##
## Gini for split: 0.239309506830061
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.248060350613439, share1: 0.470886075949367
##
## Gini2: 0.186305264189785, share1: 0.529113924050633
##
## Gini for split: 0.215384874505734
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.333197086801427, share1: 0.470886075949367
##
## Gini2: 0.0361323346757498, share1: 0.529113924050633
##
## Gini for split: 0.176015990107081
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.42559160599812, share1: 0.470886075949367
##
## Gini2: 0.0155029296875, share1: 0.529113924050633
##
## Gini for split: 0.208607977266678
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.4712, share1: 0.470886075949367
##
## Gini2: 0, share1: 0.529113924050633
##
## Gini for split: 0.221881518987342
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.491900826446281, share1: 0.470886075949367
##
## Gini2: 0, share1: 0.529113924050633
##
## Gini for split: 0.23162924992154
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.498707268026429, share1: 0.470886075949367
##
## Gini2: 0, share1: 0.529113924050633
##
## Gini for split: 0.234834308488394
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.49994341330919, share1: 0.470886075949367
##
## Gini2: 0, share1: 0.529113924050633
##
## Gini for split: 0.235416392089897
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.49951171875, share1: 0.470886075949367
##
## Gini2: 0, share1: 0.529113924050633
##
## Gini for split: 0.235213113132911
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.498698458975427, share1: 0.470886075949367
##
## Gini2: 0, share1: 0.529113924050633
##
## Gini for split: 0.234830160428935
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning: invalid factor level, NA generated
## Gini1: 0.498304758852748, share1: 0.470886075949367
##
## Gini2: 0.498304758852748, share1: 0.529113924050633
##
## Gini for split: 0.498304758852748
Теперь нарисуем как изменяется gini относительно оценки
data_frame(gini = gini) %>%
na.omit() %>%
mutate(i = 1:20) %>%
ggplot()+
geom_line(aes(x = i, y = gini))+
geom_jitter(data = student_mat, aes(x = G1, y = 0.2, color = help),alpha = 0.8, height = 0.20)+
theme_minimal()+
ylab("gini index")+
ylim(c(0, 0.5))
generated = data_frame(
G1 = 1:20,
help = c(rep("need_support", 10) ,rep("no_support", 10))
)
gini_generated = sapply(1:20, function(i){
generated %>%
mutate(g1_10 =case_when(
G1 >= i ~ "more",
T ~ "less"
)) %>%
gini_process_dplyr(help, g1_10)
})
## Gini1: 0.5, share1: 0.5
##
## Gini2: 0.5, share1: 0.5
##
## Gini for split: 0.5
## Gini1: 0, share1: 0.5
##
## Gini2: 0.498614958448754, share1: 0.5
##
## Gini for split: 0.249307479224377
## Gini1: 0, share1: 0.5
##
## Gini2: 0.493827160493827, share1: 0.5
##
## Gini for split: 0.246913580246914
## Gini1: 0, share1: 0.5
##
## Gini2: 0.484429065743945, share1: 0.5
##
## Gini for split: 0.242214532871972
## Gini1: 0, share1: 0.5
##
## Gini2: 0.46875, share1: 0.5
##
## Gini for split: 0.234375
## Gini1: 0, share1: 0.5
##
## Gini2: 0.444444444444444, share1: 0.5
##
## Gini for split: 0.222222222222222
## Gini1: 0, share1: 0.5
##
## Gini2: 0.408163265306122, share1: 0.5
##
## Gini for split: 0.204081632653061
## Gini1: 0, share1: 0.5
##
## Gini2: 0.355029585798816, share1: 0.5
##
## Gini for split: 0.177514792899408
## Gini1: 0, share1: 0.5
##
## Gini2: 0.277777777777778, share1: 0.5
##
## Gini for split: 0.138888888888889
## Gini1: 0, share1: 0.5
##
## Gini2: 0.165289256198347, share1: 0.5
##
## Gini for split: 0.0826446280991736
## Gini1: 0, share1: 0.5
##
## Gini2: 0, share1: 0.5
##
## Gini for split: 0
## Gini1: 0.165289256198347, share1: 0.5
##
## Gini2: 0, share1: 0.5
##
## Gini for split: 0.0826446280991736
## Gini1: 0.277777777777778, share1: 0.5
##
## Gini2: 0, share1: 0.5
##
## Gini for split: 0.138888888888889
## Gini1: 0.355029585798816, share1: 0.5
##
## Gini2: 0, share1: 0.5
##
## Gini for split: 0.177514792899408
## Gini1: 0.408163265306122, share1: 0.5
##
## Gini2: 0, share1: 0.5
##
## Gini for split: 0.204081632653061
## Gini1: 0.444444444444444, share1: 0.5
##
## Gini2: 0, share1: 0.5
##
## Gini for split: 0.222222222222222
## Gini1: 0.46875, share1: 0.5
##
## Gini2: 0, share1: 0.5
##
## Gini for split: 0.234375
## Gini1: 0.484429065743945, share1: 0.5
##
## Gini2: 0, share1: 0.5
##
## Gini for split: 0.242214532871972
## Gini1: 0.493827160493827, share1: 0.5
##
## Gini2: 0, share1: 0.5
##
## Gini for split: 0.246913580246914
## Gini1: 0.498614958448754, share1: 0.5
##
## Gini2: 0, share1: 0.5
##
## Gini for split: 0.249307479224377
data_frame(gini = gini_generated) %>%
mutate(i = 1:20) %>%
ggplot()+
geom_line(aes(x = i, y = gini))+
geom_jitter(data = generated, aes(x = G1, y = 0.2, color = help),alpha = 0.8, height = 0.20)+
theme_minimal()+
ylab("gini index")+
ylim(c(0, 0.5))
То же самое сделаем с другой бинарной переменной. Какой нужно выбрать?
student_mat %>%
gini_process_dplyr(help, higher)
## Warning: Problem with `mutate()` input `..1`.
## ℹ invalid factor level, NA generated
## ℹ Input `..1` is `across(everything(), ~replace_na(.x, 0))`.
## Warning in `[<-.factor`(`*tmp*`, !is_complete(data), value = 0): invalid factor
## level, NA generated
## Gini1: 0.32, share1: 0.470886075949367
##
## Gini2: 0.495644444444445, share1: 0.529113924050633
##
## Gini for split: 0.412935921237693
## splitvar_no
## 0.4129359
Нарисуем наши разбиения
student_mat %>%
ggplot(aes(x = G1, y = G2, color = help))+
geom_jitter(alpha = 0.8)+
theme_minimal()+
annotate("rect", alpha = .1, fill = "blue", color = 1,
xmin = 0, xmax = 9,
ymin = -Inf, ymax = Inf)+
annotate("rect", alpha = .1, fill = "yellow", color = 1,
xmin = -Inf, xmax = Inf,
ymin = 0, ymax = 10)
student_mat %>%
ggplot(aes(x = higher, fill = help))+
geom_bar(position = "fill")
student_mat %>%
ggplot(aes(x = higher, fill = help))+
geom_bar()
Другая метрика это information gain, которая связана с изменением энтропии.
Эта метрика показывает вероятноять выбрать переменную которая принадлежит к интересному для нас классу, при условии, что мы будем выбирать случайно
Так если после некоторого разделения у нас осталось 4 студента и все из них в группе, которой нужна помощь, то мы получаем 0 энтропию, так как точно можем сказать студента из какой группы мы выберем, если будем выбирать случайно.
Information Gain отражает разницу между суммой энтропии класса и атрибута минус вероятность класса при условии атрибута
Так если у атриутов по отдельности высокая энтропия: из 4 студентов 2 нуждаются а 2 не нуждаются в помощи, а при сочетании с переменной высокая оценка
library(FSelectorRcpp)
?information_gain
information_gain( # Calculate the score for each attribute
formula = help ~ ., # that is on the right side of the formula.
data = student_mat, # Attributes must exist in the passed data.
type = "infogain" # Choose the type of a score to be calculated.
) %>%
arrange(-importance)
## attributes importance
## 1 G2 5.190839e-01
## 2 G1 4.061086e-01
## 3 failures 4.659940e-02
## 4 Fedu 2.150424e-02
## 5 Mjob 2.043180e-02
## 6 Medu 1.855657e-02
## 7 higher 1.220407e-02
## 8 schoolsup 9.296121e-03
## 9 address 5.443534e-03
## 10 Fjob 4.593047e-03
## 11 internet 4.426085e-03
## 12 school 3.564076e-03
## 13 guardian 2.684456e-03
## 14 sex 2.571527e-03
## 15 reason 2.480221e-03
## 16 Pstatus 1.526086e-03
## 17 paid 9.277553e-04
## 18 famsize 8.505739e-04
## 19 famsup 5.031213e-04
## 20 nursery 3.626333e-04
## 21 romantic 4.107437e-05
## 22 activities 2.161684e-05
## 23 age 0.000000e+00
## 24 traveltime 0.000000e+00
## 25 studytime 0.000000e+00
## 26 famrel 0.000000e+00
## 27 freetime 0.000000e+00
## 28 goout 0.000000e+00
## 29 Dalc 0.000000e+00
## 30 Walc 0.000000e+00
## 31 health 0.000000e+00
## 32 absences 0.000000e+00
# %>%
# cut_attrs( # Then take attributes with the highest rank.
# k = 2 # For example: 2 attrs with the higehst rank.
# ) %>%
# to_formula( # Create a new formula object with
# attrs = ., # the most influencial attrs.
# class = "Species"
# )
ещё одна важная часть
?initial_split
student_mat_split <- initial_split(student_mat, prop = 3/4)
train_data <- training(student_mat_split)
test_data <- testing(student_mat_split)
train_data
## # A tibble: 297 x 33
## school sex age address famsize Pstatus Medu Fedu Mjob Fjob reason
## <fct> <fct> <dbl> <fct> <fct> <fct> <fct> <fct> <fct> <fct> <fct>
## 1 GP F 15 U GT3 T high… prim… heal… serv… home
## 2 GP F 16 U GT3 T seco… seco… other other home
## 3 GP M 15 U GT3 T seco… high… other other home
## 4 GP F 15 U GT3 T prim… prim… serv… other reput…
## 5 GP M 15 U LE3 T high… high… heal… serv… course
## 6 GP M 15 U GT3 T high… seco… teac… other course
## 7 GP M 15 U GT3 A prim… prim… other other home
## 8 GP F 16 U GT3 T high… high… serv… serv… reput…
## 9 GP F 16 U GT3 T seco… seco… other other reput…
## 10 GP M 16 U LE3 T high… seco… heal… other home
## # … with 287 more rows, and 22 more variables: guardian <fct>,
## # traveltime <dbl>, studytime <dbl>, failures <dbl>, schoolsup <fct>,
## # famsup <fct>, paid <fct>, activities <fct>, nursery <fct>, higher <fct>,
## # internet <fct>, romantic <fct>, famrel <dbl>, freetime <dbl>, goout <dbl>,
## # Dalc <dbl>, Walc <dbl>, health <dbl>, absences <dbl>, G1 <dbl>, G2 <dbl>,
## # help <fct>
library(partykit)
library(ggparty)
?partysplit
ps1 = partykit::partysplit(which(names(student_mat) == "G1"), breaks = 9)
ps2 = partykit::partysplit(which(names(student_mat) == "G2"), breaks = 10)
# ps11 = partysplit(which(names(adult_data) == "education-num"), breaks = 12)
ps111 = partykit::partysplit(which(names(student_mat) == "higher"), index = c(1L, 2L))
partykit::character_split(ps111, student_mat)
## $name
## [1] "higher"
##
## $levels
## [1] "no" "yes"
which(names(student_mat) == "higher")
## [1] 21
# pn1 = partynode(1, split = ps1,
# kids = list(
# partynode(id = 1, info = ">50K"),
# partynode(id = 2, info = "<=50K")
# ))
pn1 = partynode(1L, split = ps1,
kids = list(
partynode(2L, split = ps2, kids = list(
partynode(4L, split = ps111, kids = list(
partynode(6L, info = "need_support"),
partynode(7L, info = "no_support")
)),
partynode(5L, info = "no_support")
)),
partynode(3L, info = "no_support")
))
pn1
## [1] root
## | [2] V31 <= 9
## | | [4] V32 <= 10
## | | | [6] V21 <= 1 *
## | | | [7] V21 > 1 *
## | | [5] V32 > 10 *
## | [3] V31 > 9 *
py1 <- party(pn1, student_mat)
py1
## [1] root
## | [2] G1 <= 9
## | | [3] G2 <= 10
## | | | [4] higher in no: need_support
## | | | [5] higher in yes: no_support
## | | [6] G2 > 10: no_support
## | [7] G1 > 9: no_support
plot(py1)
ggparty(py1) +
geom_edge() +
# geom_edge_label() +
geom_node_splitvar() +
# pass list to gglist containing all ggplot components we want to plot for each
# (default: terminal) node
geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = help),
position = position_fill()),
xlab("help"),
ylab("%"),
theme_minimal()))
gp1 =ggparty(py1, terminal_space = 0.2) +
geom_edge() +
# geom_edge_label() +
geom_node_splitvar() +
geom_node_plot(ids = 1,
# size = 1.5,
# height = 0.5,
gglist = list(
geom_histogram(
aes(
x = G1, fill = help
)),
geom_vline(aes(xintercept = 9)),
theme_bw(),
theme(axis.title.y = element_blank(), legend.position='none')
))+
geom_node_plot(ids = 2,
# size = 1.5,
# height = 0.5,
gglist = list(
geom_histogram(
aes(
x = G2, fill = help
)),
geom_vline(aes(xintercept = 10)),
theme_bw(),
theme(axis.title.y = element_blank(), legend.position='none')
))+
geom_node_plot(ids = 3,
# size = 1.5,
# height = 0.5,
gglist = list(
geom_bar(
aes(
x = higher, fill = help
), position = "fill"),
# scale_alpha_discrete(range = c(0.5, 1)),
theme_bw(),
theme(axis.title.y = element_blank(), axis.text.x = element_text(angle = 45, hjust = 1), legend.position='none')
))+
# pass list to gglist containing all ggplot components we want to plot for each
# (default: terminal) node
geom_node_plot(gglist = list(geom_bar(aes(x = "", fill = help),
position = position_fill()),
xlab("help"),
ylab("%"),
theme_minimal(),
theme(legend.position='none')))+
theme(legend.position = "none")
gp1
ggsave(gp1,scale = 0.4, filename = "ggparty_students.png")