# rules: rule-based models in tidymodels

library(tidymodels)
library(rules)

library(rules)
#> Loading required package: parsnip
# Load the penguins data shipped with modeldata and look at the first six rows.
data("penguins", package = "modeldata")
utils::head(penguins, n = 6L)
## # A tibble: 6 × 7
##   species island    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
##   <fct>   <fct>              <dbl>         <dbl>             <int>       <int>
## 1 Adelie  Torgersen           39.1          18.7               181        3750
## 2 Adelie  Torgersen           39.5          17.4               186        3800
## 3 Adelie  Torgersen           40.3          18                 195        3250
## 4 Adelie  Torgersen           NA            NA                  NA          NA
## 5 Adelie  Torgersen           36.7          19.3               193        3450
## 6 Adelie  Torgersen           39.3          20.6               190        3650
## # ℹ 1 more variable: sex <fct>
# Fit a Cubist rule-based regression model with two committees that predicts
# body mass from every other column, then print the fitted model.
cubist_fit <- fit(
  set_engine(cubist_rules(committees = 2), "Cubist"),
  body_mass_g ~ .,
  data = penguins
)
print(cubist_fit)
## parsnip model object
## 
## 
## Call:
## cubist.default(x = x, y = y, committees = 2)
## 
## Number of samples: 333 
## Number of predictors: 6 
## 
## Number of committees: 2 
## Number of rules per committee: 5, 1
# Show the full Cubist summary: rules, their linear models, training-set
# error and attribute usage. `extract_fit_engine()` is parsnip's accessor for
# the underlying engine object; use it instead of reaching into `$fit`.
summary(extract_fit_engine(cubist_fit))
## 
## Call:
## cubist.default(x = x, y = y, committees = 2)
## 
## 
## Cubist [Release 2.07 GPL Edition]  Tue Dec 26 23:52:38 2023
## ---------------------------------
## 
##     Target attribute `outcome'
## 
## Read 333 cases (7 attributes) from undefined.data
## 
## Model 1:
## 
##   Rule 1/1: [107 cases, mean 3419.2, range 2700 to 4150, est err 208.3]
## 
##     if
##  flipper_length_mm <= 202
##  sex = female
##     then
##  outcome = -1068 + 108 bill_depth_mm + 10.7 flipper_length_mm
##            + 14 bill_length_mm
## 
##   Rule 1/2: [92 cases, mean 3972.0, range 3250 to 4775, est err 275.6]
## 
##     if
##  flipper_length_mm <= 202
##  sex = male
##     then
##  outcome = 319.1 + 22.3 flipper_length_mm - 21 bill_length_mm
##            + 12 bill_depth_mm
## 
##   Rule 1/3: [58 cases, mean 4679.7, range 3950 to 5200, est err 206.6]
## 
##     if
##  flipper_length_mm > 202
##  sex = female
##     then
##  outcome = -3923.3 + 30.4 flipper_length_mm + 136 bill_depth_mm
##            + 5 bill_length_mm
## 
##   Rule 1/4: [23 cases, mean 4698.9, range 3950 to 6050, est err 275.8]
## 
##     if
##  bill_depth_mm > 16.4
##  flipper_length_mm > 202
##     then
##  outcome = -7845.8 + 58.6 flipper_length_mm
## 
##   Rule 1/5: [53 cases, mean 5475.0, range 4750 to 6300, est err 239.3]
## 
##     if
##  bill_depth_mm <= 16.4
##  sex = male
##     then
##  outcome = -138.7 + 46 bill_length_mm + 89 bill_depth_mm
##            + 8.9 flipper_length_mm
## 
## Model 2:
## 
##   Rule 2/1: [333 cases, mean 4207.1, range 2700 to 6300, est err 315.4]
## 
##  outcome = -5815.1 + 49.7 flipper_length_mm
## 
## 
## Evaluation on training data (333 cases):
## 
##     Average  |error|              278.7
##     Relative |error|               0.41
##     Correlation coefficient        0.90
## 
## 
##  Attribute usage:
##    Conds  Model
## 
##     47%           sex
##     42%   100%    flipper_length_mm
##     11%    47%    bill_depth_mm
##            47%    bill_length_mm
## 
## 
## Time: 0.0 secs
# Convert the Cubist model to a tibble with one row per committee/rule
# (rule text plus nested coefficient and statistic tibbles), then print it.
cb_res <- cubist_fit %>% tidy()
print(cb_res)
## # A tibble: 6 × 5
##   committee rule_num rule                                     estimate statistic
##       <int>    <int> <chr>                                    <list>   <list>   
## 1         1        1 ( sex == 'female' ) & ( flipper_length_… <tibble> <tibble> 
## 2         1        2 ( sex == 'male' ) & ( flipper_length_mm… <tibble> <tibble> 
## 3         1        3 ( flipper_length_mm > 202 ) & ( sex == … <tibble> <tibble> 
## 4         1        4 ( flipper_length_mm > 202 ) & ( bill_de… <tibble> <tibble> 
## 5         1        5 ( bill_depth_mm <= 16.4 ) & ( sex == 'm… <tibble> <tibble> 
## 6         2        1 <no conditions>                          <tibble> <tibble>
library(tidyr)
# Flatten the nested per-rule statistics (conditions, coverage, mean, range,
# error) into ordinary columns alongside the committee and rule number.
unnest(
  dplyr::select(cb_res, committee, rule_num, statistic),
  cols = statistic
)
## # A tibble: 6 × 8
##   committee rule_num num_conditions coverage  mean   min   max error
##       <int>    <int>          <dbl>    <dbl> <dbl> <dbl> <dbl> <dbl>
## 1         1        1              2      107 3419.  2700  4150  208.
## 2         1        2              2       92 3972   3250  4775  276.
## 3         1        3              2       58 4680.  3950  5200  207.
## 4         1        4              2       23 4699.  3950  6050  276.
## 5         1        5              2       53 5475   4750  6300  239.
## 6         2        1              0      333 4207.  2700  6300  315.
library(dplyr)
library(purrr)
library(rlang)

# Turn the text of Cubist rule 1/4 (committee 1, rule 4) back into an R
# expression and evaluate it against the data to find the rows it covers.
# Filter on `committee` as well as `rule_num`: rule numbers restart in every
# committee, and `parse_expr()` needs exactly one string.
rule_4_filter <- 
  cb_res %>% 
  dplyr::filter(committee == 1, rule_num == 4) %>% 
  pluck("rule") %>%   # <- character string
  parse_expr() %>%    # <- R expression
  eval_tidy(penguins) # <- logical vector (NA where a predictor is missing)

# `which()` drops NA entries, so rows with missing predictor values are
# excluded from the selection.
penguins %>% 
  dplyr::slice(which(rule_4_filter))
## # A tibble: 23 × 7
##    species island    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
##    <fct>   <fct>              <dbl>         <dbl>             <int>       <int>
##  1 Adelie  Dream               41.1          18.1               205        4300
##  2 Adelie  Dream               40.8          18.9               208        4300
##  3 Adelie  Biscoe              41            20                 203        4725
##  4 Adelie  Torgersen           44.1          18                 210        4000
##  5 Gentoo  Biscoe              59.6          17                 230        6050
##  6 Gentoo  Biscoe              44.4          17.3               219        5250
##  7 Gentoo  Biscoe              49.8          16.8               230        5700
##  8 Gentoo  Biscoe              50.8          17.3               228        5600
##  9 Gentoo  Biscoe              52.1          17                 230        5550
## 10 Gentoo  Biscoe              52.2          17.1               228        5400
## # ℹ 13 more rows
## # ℹ 1 more variable: sex <fct>
# Prepare the data for the RuleFit model.
# - Store the outcome as a double instead of an integer. NOTE(review):
#   presumably the xrf engine expects a double outcome; confirm.
# - Drop every row with a missing value in any column.
# This overwrites `penguins` for all code that follows.
penguins <- 
  penguins %>% 
  mutate(body_mass_g = as.double(body_mass_g)) %>% 
  na.omit()


# Specify a RuleFit regression model with the xrf engine: 10 trees with depth
# at most 5 and penalty set to 0.01.
rule_fit_spec <- set_mode(
  set_engine(
    rule_fit(trees = 10, tree_depth = 5, penalty = 0.01),
    "xrf"
  ),
  "regression"
)

# Fit it to predict body mass from every other column, then print the model.
rule_fit_fit <- fit(rule_fit_spec, body_mass_g ~ ., data = penguins)
print(rule_fit_fit)
## parsnip model object
## 
## An eXtreme RuleFit model of 112 rules.
## 
## Original Formula:
## 
## body_mass_g ~ species + island + bill_length_mm + bill_depth_mm + flipper_length_mm + [truncated]
# One row per retained term (intercept, linear terms and rules) with its
# coefficient at penalty = 0.01; assign and print in a single step.
(rf_res <- tidy(rule_fit_fit, penalty = 0.01))
## # A tibble: 110 × 3
##    rule_id           rule                                               estimate
##    <chr>             <chr>                                                 <dbl>
##  1 (Intercept)       ( TRUE )                                            5753.  
##  2 bill_depth_mm     ( bill_depth_mm )                                    -15.9 
##  3 bill_length_mm    ( bill_length_mm )                                     1.10
##  4 flipper_length_mm ( flipper_length_mm )                                 -7.67
##  5 islandDream       ( island == 'Dream' )                                -29.2 
##  6 islandTorgersen   ( island == 'Torgersen' )                            -16.8 
##  7 r0_2              ( species == 'Gentoo' )                              472.  
##  8 r0_3              ( sex != 'male' ) & ( species != 'Gentoo' )          -78.8 
##  9 r1_2              ( flipper_length_mm >= 211.5 )                       219.  
## 10 r1_3              ( flipper_length_mm <  194.5 ) & ( flipper_length…  -142.  
## # ℹ 100 more rows
# The same coefficients at penalty = 0.01, broken out by the predictor
# columns that each rule or term involves.
rf_variable_res <- tidy(rule_fit_fit, penalty = 0.01, unit = "columns")
print(rf_variable_res)
## # A tibble: 452 × 3
##    rule_id term              estimate
##    <chr>   <chr>                <dbl>
##  1 r0_3    species             -78.8 
##  2 r0_2    species             472.  
##  3 r0_3    sex                 -78.8 
##  4 r1_3    flipper_length_mm  -142.  
##  5 r1_2    flipper_length_mm   219.  
##  6 r1_3    flipper_length_mm  -142.  
##  7 r2_7    flipper_length_mm     9.04
##  8 r2_7    sex                   9.04
##  9 r2_7    species               9.04
## 10 r3_3    flipper_length_mm  -101.  
## # ℹ 442 more rows
# Count the distinct rules the model kept. Rule ids look like "r0_2" or
# "r1_3". The pattern uses `+` rather than `*` so that a linear term for a
# column whose name merely starts with "r_" is not mistaken for a rule.
# The extra 1 presumably accounts for each variable's own linear term
# (NOTE(review): confirm).
num_rules <- sum(grepl("^r[0-9]+_", unique(rf_res$rule_id))) + 1

# Variable importance: the total absolute coefficient of the rules and linear
# terms that involve each variable, divided by the number of possible
# occurrences.
rf_variable_res %>% 
  dplyr::filter(term != "(Intercept)") %>% 
  # A rule with several conditions on the same variable (e.g. a range on
  # flipper_length_mm) yields one row per condition. Keep one row per
  # rule/variable pair so such rules are not counted twice.
  dplyr::distinct(rule_id, term, .keep_all = TRUE) %>% 
  group_by(term) %>% 
  summarize(effect = sum(abs(estimate)), .groups = "drop") %>% 
  # normalize by number of possible occurrences
  mutate(effect = effect / num_rules) %>% 
  arrange(desc(effect))
## # A tibble: 6 × 2
##   term              effect
##   <chr>              <dbl>
## 1 flipper_length_mm 219.  
## 2 bill_depth_mm     175.  
## 3 bill_length_mm    175.  
## 4 species            69.5 
## 5 sex                52.0 
## 6 island              9.84