ביחידה זו נלמד על שתי שיטות נוספות המשמשות לחיזוי ורגרסיה: עצים, ויערות אקראיים (trees and random forests).

תזכורת: באחת היחידות הקודמות התאמנו מודל רגרסיה לינארית למחיר יהלומים. כאשר פיצלנו את המודל לשני מודלים העובדים כל אחד על תחום של משתנה קטגורי, איכות החיזוי של המודל השתפרה.

הרעיון הכללי של עצים הוא דומה - לפצל את המרחב, כל פעם לפי משתנה אחר, עד שמגיעים ל“עלים” בהם ניתן החיזוי (כרגרסיה או סיווג).

עצים הם “נחמדים” כי הם מושכים ויזואלית ונוח לפרש אותם, אבל בדרך כלל הם לא נותנים תוצאות טובות, ומאוד רגישים לשינויים קלים (לדוגמה להוספה או החסרה של תצפיות). לכן, הכללות מקובלות לעצים הם כאלו המשכפלות את התהליך פעמים רבות (לדוגמה ל“יער אקראי”, שבו גם נדון ביחידה זו). שכפול התהליך לעצים מרובים הופך את המודל ליותר רובוסטי ויותר מדויק, אבל קצת פחות ברור לפרשנות.

library(tidyverse)

ggplot(diamonds, aes(y = price, x = carat)) + 
  facet_wrap(~ clarity) + 
  stat_smooth(method = "lm")

library(rpart)

diamond_price_tree <- rpart(formula = price ~ ., 
                            data = diamonds)

library(rpart.plot)
prp(diamond_price_tree)

diamond_price_tree
## n= 53940 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 53940 858473100000  3932.800  
##    2) carat< 0.995 34880  43459420000  1632.641  
##      4) y< 5.535 24951   6860691000  1058.546 *
##      5) y>=5.535 9929   7710112000  3075.309 *
##    3) carat>=0.995 19060 292761600000  8142.115  
##      6) y< 7.195 12884  60679350000  6137.844  
##       12) clarity=I1,SI2,SI1,VS2 9804  20256360000  5397.093 *
##       13) clarity=VS1,VVS2,VVS1,IF 3080  17919640000  8495.739 *
##      7) y>=7.195 6176  72354930000 12323.300  
##       14) y< 7.815 3945  33996520000 10899.960  
##         28) clarity=I1,SI2 954   3380193000  8375.178 *
##         29) clarity=SI1,VS2,VS1,VVS2,VVS1,IF 2991  22595360000 11705.260  
##           58) color=H,I,J 1554   5588830000 10014.970 *
##           59) color=D,E,F,G 1437   7765314000 13533.160 *
##       15) y>=7.815 2231  16233830000 14840.160 *
summary(diamond_price_tree)
## Call:
## rpart(formula = price ~ ., data = diamonds)
##   n= 53940 
## 
##           CP nsplit  rel error    xerror        xstd
## 1 0.60834997      0 1.00000000 1.0000821 0.008801006
## 2 0.18605980      1 0.39165003 0.3916954 0.003926316
## 3 0.03365116      2 0.20559023 0.2056335 0.002060823
## 4 0.02621322      3 0.17193908 0.1720624 0.002058132
## 5 0.02577201      4 0.14572586 0.1483153 0.001839371
## 6 0.01005400      5 0.11995385 0.1203474 0.001549010
## 7 0.01000000      7 0.09984584 0.1087609 0.001420495
## 
## Variable importance
##   carat       y       x       z clarity   color 
##      25      24      24      23       3       1 
## 
## Node number 1: 53940 observations,    complexity param=0.60835
##   mean=3932.8, MSE=1.591533e+07 
##   left son=2 (34880 obs) right son=3 (19060 obs)
##   Primary splits:
##       carat < 0.995 to the left,  improve=0.60835000, (0 missing)
##       y     < 6.345 to the left,  improve=0.60688030, (0 missing)
##       x     < 6.335 to the left,  improve=0.60323470, (0 missing)
##       z     < 3.915 to the left,  improve=0.59819760, (0 missing)
##       color splits as  LLLLRRR,   improve=0.02222238, (0 missing)
##   Surrogate splits:
##       x       < 6.275 to the left,  agree=0.984, adj=0.953, (0 split)
##       y       < 6.285 to the left,  agree=0.981, adj=0.946, (0 split)
##       z       < 3.895 to the left,  agree=0.978, adj=0.937, (0 split)
##       clarity splits as  RRLLLLLL,  agree=0.679, adj=0.091, (0 split)
##       color   splits as  LLLLLRR,   agree=0.660, adj=0.039, (0 split)
## 
## Node number 2: 34880 observations,    complexity param=0.03365116
##   mean=1632.641, MSE=1245969 
##   left son=4 (24951 obs) right son=5 (9929 obs)
##   Primary splits:
##       y       < 5.535 to the left,  improve=0.66472620, (0 missing)
##       carat   < 0.625 to the left,  improve=0.66450550, (0 missing)
##       x       < 5.485 to the left,  improve=0.66202330, (0 missing)
##       z       < 3.375 to the left,  improve=0.66100870, (0 missing)
##       clarity splits as  RRRLLLLL,  improve=0.01153592, (0 missing)
##   Surrogate splits:
##       x       < 5.535 to the left,  agree=0.992, adj=0.972, (0 split)
##       carat   < 0.635 to the left,  agree=0.991, adj=0.969, (0 split)
##       z       < 3.405 to the left,  agree=0.984, adj=0.942, (0 split)
##       clarity splits as  RRLLLLLL,  agree=0.726, adj=0.037, (0 split)
##       table   < 62.25 to the left,  agree=0.719, adj=0.012, (0 split)
## 
## Node number 3: 19060 observations,    complexity param=0.1860598
##   mean=8142.115, MSE=1.536e+07 
##   left son=6 (12884 obs) right son=7 (6176 obs)
##   Primary splits:
##       y       < 7.195 to the left,  improve=0.54558840, (0 missing)
##       x       < 7.195 to the left,  improve=0.53787920, (0 missing)
##       carat   < 1.495 to the left,  improve=0.53685280, (0 missing)
##       z       < 4.425 to the left,  improve=0.52466320, (0 missing)
##       clarity splits as  LLLRRRRR,  improve=0.05520053, (0 missing)
##   Surrogate splits:
##       x     < 7.185 to the left,  agree=0.984, adj=0.950, (0 split)
##       carat < 1.445 to the left,  agree=0.980, adj=0.938, (0 split)
##       z     < 4.405 to the left,  agree=0.965, adj=0.891, (0 split)
##       color splits as  LLLLLLR,   agree=0.679, adj=0.010, (0 split)
##       table < 67.5  to the left,  agree=0.676, adj=0.001, (0 split)
## 
## Node number 4: 24951 observations
##   mean=1058.546, MSE=274966.6 
## 
## Node number 5: 9929 observations
##   mean=3075.309, MSE=776524.5 
## 
## Node number 6: 12884 observations,    complexity param=0.02621322
##   mean=6137.844, MSE=4709667 
##   left son=12 (9804 obs) right son=13 (3080 obs)
##   Primary splits:
##       clarity splits as  LLLLRRRR,  improve=0.3708568, (0 missing)
##       y       < 6.775 to the left,  improve=0.1220791, (0 missing)
##       carat   < 1.175 to the left,  improve=0.1068320, (0 missing)
##       x       < 6.745 to the left,  improve=0.1059436, (0 missing)
##       color   splits as  RRRRLLL,   improve=0.1042738, (0 missing)
##   Surrogate splits:
##       table < 49.5  to the right, agree=0.761, adj=0.001, (0 split)
## 
## Node number 7: 6176 observations,    complexity param=0.02577201
##   mean=12323.3, MSE=1.17155e+07 
##   left son=14 (3945 obs) right son=15 (2231 obs)
##   Primary splits:
##       y       < 7.815 to the left,  improve=0.30577850, (0 missing)
##       x       < 7.845 to the left,  improve=0.29884120, (0 missing)
##       carat   < 1.915 to the left,  improve=0.29280650, (0 missing)
##       z       < 4.805 to the left,  improve=0.27696390, (0 missing)
##       clarity splits as  LRRRRRRR,  improve=0.07250237, (0 missing)
##   Surrogate splits:
##       x       < 7.845 to the left,  agree=0.983, adj=0.953, (0 split)
##       carat   < 1.825 to the left,  agree=0.971, adj=0.921, (0 split)
##       z       < 4.825 to the left,  agree=0.951, adj=0.866, (0 split)
##       clarity splits as  RRLLLLLL,  agree=0.677, adj=0.106, (0 split)
##       depth   < 57.05 to the right, agree=0.640, adj=0.003, (0 split)
## 
## Node number 12: 9804 observations
##   mean=5397.093, MSE=2066132 
## 
## Node number 13: 3080 observations
##   mean=8495.739, MSE=5818064 
## 
## Node number 14: 3945 observations,    complexity param=0.010054
##   mean=10899.96, MSE=8617622 
##   left son=28 (954 obs) right son=29 (2991 obs)
##   Primary splits:
##       clarity splits as  LLRRRRRR,  improve=0.23593490, (0 missing)
##       color   splits as  RRRRLLL,   improve=0.22434970, (0 missing)
##       y       < 7.565 to the left,  improve=0.03876909, (0 missing)
##       carat   < 1.615 to the left,  improve=0.03231217, (0 missing)
##       x       < 7.565 to the left,  improve=0.02811208, (0 missing)
##   Surrogate splits:
##       z     < 4.855 to the right, agree=0.769, adj=0.045, (0 split)
##       depth < 64.45 to the right, agree=0.768, adj=0.040, (0 split)
##       carat < 1.825 to the right, agree=0.767, adj=0.035, (0 split)
##       cut   splits as  LRRRR,     agree=0.764, adj=0.024, (0 split)
##       x     < 7.865 to the right, agree=0.760, adj=0.008, (0 split)
## 
## Node number 15: 2231 observations
##   mean=14840.16, MSE=7276482 
## 
## Node number 28: 954 observations
##   mean=8375.178, MSE=3543180 
## 
## Node number 29: 2991 observations,    complexity param=0.010054
##   mean=11705.26, MSE=7554451 
##   left son=58 (1554 obs) right son=59 (1437 obs)
##   Primary splits:
##       color   splits as  RRRRLLL,   improve=0.40898740, (0 missing)
##       clarity splits as  LLLRRRRR,  improve=0.09908389, (0 missing)
##       carat   < 1.615 to the left,  improve=0.05752247, (0 missing)
##       y       < 7.595 to the left,  improve=0.05099388, (0 missing)
##       z       < 4.715 to the left,  improve=0.04452641, (0 missing)
##   Surrogate splits:
##       carat < 1.545 to the right, agree=0.562, adj=0.089, (0 split)
##       z     < 4.595 to the right, agree=0.556, adj=0.076, (0 split)
##       y     < 7.325 to the right, agree=0.538, adj=0.038, (0 split)
##       x     < 7.435 to the right, agree=0.538, adj=0.038, (0 split)
##       depth < 60.85 to the right, agree=0.537, adj=0.037, (0 split)
## 
## Node number 58: 1554 observations
##   mean=10014.97, MSE=3596416 
## 
## Node number 59: 1437 observations
##   mean=13533.16, MSE=5403837

פרמטרים שונים שולטים על עומק העץ. ככל שעץ עמוק יותר, כך אנחנו ניכנס למצבים של Over-fitting. הנה גידול עץ עמוק במיוחד (וכנראה לא מאוד מועיל).

diamond_price_tree_large <- rpart(formula = price ~ ., 
                                  data = diamonds,
                                  control = rpart.control(cp = 0.0005, xval = 10))
prp(diamond_price_tree_large)

#summary(diamond_price_tree_large)

הפרמטר cp שאותו שינינו כדי לשלוט על גודל העץ הוא פרמטר מורכבות העץ (complexity parameter). הוא שולט על אלגוריתם הגידול של העץ. כאשר הפרמטר נמוך, האלגוריתם נוטה לבצע יותר פיצולים, וכאשר הפרמטר גבוה, ישנם פחות פיצולים. למעשה הפרמטר מציב רף לפיצול, ורק אם פיצול משפר את החיזוי בערכו של הפרמטר, אז מתבצע הפיצול.

למעשה, אפשר גם לגדל עץ עמוק ואז לגזום אותו prune, כדי לייצר עץ קטן בחזרה. באופן מסויים זה מזכיר את האלגוריתם של step wise selection (במודל של רגרסיה) שאותו הזכרנו ביחידה קודמת. שני האלגוריתמים “הולכים אחורה” ומנסים להוריד משתנים.

איך עובד האלגוריתם של גידול וגיזום עצים?

גידול וגיזום עצים

האלגוריתם מחלק את מרחב התצפיות \(X\) למלבנים, בכל מלבן מתבצע ממוצע של ערכי התצפיות \(y\), וזו התחזית הניתנת לתצפיות חדשות השייכות לאותו המלבן.

כלומר, האלגוריתם מנסה למזער את הגודל הבא:

\[\sum_{j=1}^J\sum_{i\in R_j}\left(y_i-\hat{y}_{R_j}\right)^2\]

כאשר \(J\) הוא מספר המלבנים אליהם מחולק מרחב ה-\(X\).

בכל רגע נתון באלגוריתם, נבחן הפיצול הטוב ביותר כרגע, כלומר, האלגוריתם מחפש את המשתנה ה-\(X_j\) והערך הקריטי \(s\), כך שיביא למינימום את הגודל:

\[\sum_{i: x_i\in R_1(j,s)}\left(y_i-\hat{y}_{R_1}\right)^2 + \sum_{i: x_i\in R_2(j,s)}\left(y_i-\hat{y}_{R_2}\right)^2\]

כאשר:

\[R_1(j,s) = \left\{X|X_j<s\right\} \text{ and } R_2(j,s) = \left\{X|X_j\geq s\right\}\]

במילים אחרות, מדובר באלגוריתם “חמדן” (הוא תמיד מחפש את הדבר הטוב ביותר כרגע, ולעיתים זה יכול להוביל לתוצאות לא טובות).

כדי לגזום עץ (להקטין את מידת הסיבוכיות שלו) ניתן להשתמש בפקודה prune.

diamond_price_pruned <- prune(diamond_price_tree_large, cp = 0.05)

prp(diamond_price_pruned)

שימוש ב-Cross validation לבחירת פרמטרים

כדי לבחון מה פרמטר ה-cp הרצוי, מומלץ להשתמש ב-cross validation.

מה עושה cross validation?

האלגוריתם של rpart למעשה עושה את כל זה עבורנו.

# here is the cp table
diamond_price_tree_large$cptable
##              CP nsplit  rel error     xerror         xstd
## 1  0.6083499681      0 1.00000000 1.00006475 0.0088010440
## 2  0.1860597998      1 0.39165003 0.39173358 0.0039266775
## 3  0.0336511552      2 0.20559023 0.20564109 0.0020608645
## 4  0.0262132207      3 0.17193908 0.17206300 0.0020582035
## 5  0.0257720103      4 0.14572586 0.14453736 0.0017669383
## 6  0.0100540028      5 0.11995385 0.12076185 0.0015626320
## 7  0.0059658108      7 0.09984584 0.09960697 0.0013044269
## 8  0.0059115945      8 0.09388003 0.09427032 0.0012511756
## 9  0.0050256230      9 0.08796843 0.08891238 0.0012003879
## 10 0.0039210262     10 0.08294281 0.08398700 0.0012026078
## 11 0.0034501246     11 0.07902179 0.08159645 0.0011398941
## 12 0.0032341022     12 0.07557166 0.07568072 0.0010737729
## 13 0.0031204544     13 0.07233756 0.07353013 0.0010593893
## 14 0.0028156712     14 0.06921710 0.07104537 0.0010334381
## 15 0.0020441493     15 0.06640143 0.06711342 0.0009792445
## 16 0.0017401225     16 0.06435728 0.06587602 0.0009567098
## 17 0.0016567915     17 0.06261716 0.06384242 0.0009320158
## 18 0.0011974445     19 0.05930358 0.06028437 0.0008966664
## 19 0.0011041126     20 0.05810613 0.05842785 0.0008714159
## 20 0.0010277728     21 0.05700202 0.05785724 0.0008638224
## 21 0.0009978567     22 0.05597425 0.05718099 0.0008425371
## 22 0.0009402965     23 0.05497639 0.05508649 0.0008258311
## 23 0.0008722432     24 0.05403610 0.05428277 0.0008136440
## 24 0.0008643220     25 0.05316385 0.05412565 0.0008082537
## 25 0.0008225882     26 0.05229953 0.05368813 0.0008048566
## 26 0.0007640001     28 0.05065435 0.05195502 0.0007717015
## 27 0.0007258915     29 0.04989035 0.05106837 0.0007595777
## 28 0.0006199210     30 0.04916446 0.05038140 0.0007519656
## 29 0.0006127750     31 0.04854454 0.04934351 0.0007385028
## 30 0.0005904991     32 0.04793177 0.04903457 0.0007344071
## 31 0.0005707747     33 0.04734127 0.04845750 0.0007182557
## 32 0.0005611596     34 0.04677049 0.04775787 0.0007063952
## 33 0.0005606257     35 0.04620933 0.04734673 0.0007033127
## 34 0.0005498450     36 0.04564871 0.04715392 0.0007023142
## 35 0.0005170202     37 0.04509886 0.04650942 0.0006829172
## 36 0.0005000000     38 0.04458184 0.04586644 0.0006783903
# the shortest way - use a predefined function to plot the xval cp errors
rpart::plotcp(diamond_price_tree_large)

במקרה זה, היות שיש לנו מדגם מאוד גדול של יהלומים, השגיאה אכן קטנה כאשר הפרמטר קטן, ואנחנו עוד לא בתחום של overfitting. במקרים אחרים ייתכן שנראה גרף שאינו מונוטוני, כלומר כאשר נקטין את cp, בשלב מסוים נקבל מצב שבו השגיאה גדלה.

כמו כן, ניתן לראות שהתפוקה השולית של הקטנת ה-cp פוחתת.

עד כה, דנו בעץ רגרסיה - עץ המספק חיזוי לערכים רציפים. ישנם עצי סיווג הרלוונטיים במקרים בהם נדרש לסווג תצפיות לערכים בדידים (classification), בדומה לאלגוריתמים אחרים שדנו בהם ביחידות קודמות (knn, רגרסיה לוגיסטית, lda, ו-qda).

במקרה של עצי סיווג, לא משתמשים בשגיאת RSS אלא במדד אחר הנקרא Gini index.

\[G = \sum_{k=1}^K\hat{p}_{mk}(1-\hat{p}_{mk})\]

כאשר \(\hat{p}_{mk}\) הוא הפרופורציה של תצפיות במרחב ה-\(m\), עם סיווג \(k\). מדד זה נמוך ככל שערכי \(\hat{p}_{mk}\) הם יותר קיצוניים (קרובים ל-0 או ל-1), קרי העלים של העץ “טהורים”.

ggplot(tibble(p = seq(0, 1, 0.01)), aes(x = p, y = p*(1-p))) + 
  geom_line() + 
  ylab("G = p*(1-p)") +
  ggtitle("Illustration: Gini index will be minimized when p=1 or p=0")

בחלק מהאלגוריתמים נעשה שימוש במדד מקביל למדד Gini: מדד האנטרופיה.

\[D = -\sum_{k=1}^K{\hat{p}_{mk}\log\hat{p}_{mk}}\]

תרגיל

בתרגיל הבא נשתמש בעצי החלטה כדי לחזות את הסבירות לנטישה של לקוח.

  1. קראו את הקובץ WA_Fn-UseC_-Telco-Customer-Churn.csv.
  2. בנו מודל עץ לחיזוי הנטישה, השתמשו ב-cp קטן וב-cp גדול.
  3. כעת ציירו את שני העצים, האם אתם מצליחים להפיק תובנות כלשהן מהעצים?
  4. הציגו את שגיאת ה-cross validation כפונקציה של cp. מה ערך cp שלדעתכם נכון לבחור?
  5. כעת השתמשו בערך cp שקיבלתם בסעיף הקודם. חלקו את הנתונים ל-train/test והשתמשו בנתוני ה-test כדי לייצר תרשים ROC (באפשרותכם להיעזר ביחידה שעסקה ברגרסיה לוגיסטית על מנת להיזכר בקוד הרלוונטי).
  6. התאימו מודל רגרסיה לוגיסטית מתחרה, ובצעו השוואה של לתוצאותיו למול תוצאות העץ (ציירו את שני ה-ROC עבור האלגוריתמים שבחרתם). האם יש אלגוריתם שניתן לומר שביצועיו טובים יותר?

לסיכום

  • עצי החלטה קל להסביר למקבלי החלטות
  • יכולים לבטא קשרים שאינם ניתנים לביטוי בנוסחת רגרסיה
  • ניתן בקלות להשתמש במשתנים קטגוריאליים, וגם להתמודד עם ערכים חסרים

אבל

  • כמודל חיזוי - האם לא כל כך טובים
  • הם לא רובוסטיים (שינוי קל בנתונים עלול להביא לשינוי מאוד מהותי במבנה העץ)

לכן, נבנו מספר אלגוריתמים המבוססים על עצים אך הם יותר רובוסטים ולרוב בעלי ביצועים טובים יותר.

  • יערות אקראיים - randomForests
  • Bagging
  • Boosting

יערות אקראיים (randomForests)

האלגוריתם של יערות אקראיים בונה מקבץ של של עצים רבים, כאשר בכל עץ בכל פיצול, הוא מגביל את המשתנים לפיהם הוא יכול לפצל ל-\(m\) משתנים בלבד (מתוך \(p\) אפשריים). בדרך כלל \(m\approx \sqrt{p}\)).

כמו כן, האלגוריתם מגריל תצפיות (במקום להשתמש בכל התצפיות הוא משתמש במדגם שלהן), לצורך בנייתו של עץ.

אלגוריתם זה יכול לצמצם את ההשפעות של קורלציה בין משתנים, וכמו כן, הוא נותן הזדמנות למשתנים מסבירים שונים לבוא לידי ביטוי, אפילו אם הם לא בעלי העוצמה החזקה ביותר.

לבסוף התוצר המתקבל הוא ממוצע החיזויים על פני כלל העצים.

ערך מוסף בעצים הוא שניתן לחשב את ההפחתה הממוצעת במדד Gini של כל אחד מהמשתנים המסבירים, וזה מאפשר לדרג אותם לפי סדר חשיבות. בבעיית רגרסיה, החשיבות מסודרת לפי מידת “הטהורות” (במונחי RSS) שהוספת משתנה מסוים תרמה לדיוק, בממוצע.

library(randomForest)
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
## 
##     combine
## The following object is masked from 'package:ggplot2':
## 
##     margin
# note the use of maxnodes, otherwise the trees are grown to maximal size
# also limiting the number of trees to 150 - the default is 500...
diamond_price_forest <- randomForest(
  formula = price ~ .,
  data = diamonds,
  maxnodes = 15,
  ntree = 150)

# plot the importance plot
varImpPlot(diamond_price_forest)

# show an example of the first tree
getTree(diamond_price_forest, k = 1)
##    left daughter right daughter split var split point status prediction
## 1              2              3         7       6.335     -3  3927.2732
## 2              4              5         1       0.625     -3  1697.8672
## 3              6              7         5      59.650     -3  8263.3924
## 4              8              9         7       4.965     -3  1048.4733
## 5             10             11         8       6.085     -3  3209.1168
## 6             12             13         9       4.275     -3  7528.9389
## 7             14             15         1       1.475     -3  8348.2489
## 8             16             17         1       0.375     -3   784.5946
## 9             18             19         3       4.500     -3  1675.8869
## 10            20             21         4       3.500     -3  2757.5049
## 11            22             23         7       6.265     -3  4291.0948
## 12            24             25         4       3.500     -3  5427.3591
## 13            26             27         9       4.615     -3 12049.7791
## 14            28             29         8       6.775     -3  6274.9434
## 15             0              0         0       0.000     -1 12242.9665
## 16             0              0         0       0.000     -1   702.0850
## 17             0              0         0       0.000     -1   971.1848
## 18             0              0         0       0.000     -1  1735.8274
## 19             0              0         0       0.000     -1  1382.9298
## 20             0              0         0       0.000     -1  2494.9158
## 21             0              0         0       0.000     -1  3023.1710
## 22             0              0         0       0.000     -1  4086.6390
## 23             0              0         0       0.000     -1  4716.7947
## 24             0              0         0       0.000     -1  4475.0434
## 25             0              0         0       0.000     -1  7002.1926
## 26             0              0         0       0.000     -1 10158.5433
## 27             0              0         0       0.000     -1 14422.6779
## 28             0              0         0       0.000     -1  5777.2421
## 29             0              0         0       0.000     -1  7495.8592

תרגיל

  1. בנו יער אקראי לנתוני הנטישה של telco.
  2. חשבו על ה-test set את שיעור הטעות מסוג ראשון ושיעור הטעות מסוג שני.
  3. הוסיפו את הנתונים של יער זה לעקומות ה-ROC מהסעיף הקודם.
    1. שימו לב, כאשר אתם משתמשים בפונקציית predict, עליכם להגדיר את הפרמטר type בצורה מסויימת. איך?
  4. בנו תרשים של חשיבות המשתנים. מה המשתנה/ים החשוב/ים ביותר בהשפעה על נטישת/נאמנות לקוחות?
  5. מנכ“ל החברה מתלבט האם להציע הנחה ללקוחות אשר המודל צופה שינטשו. סמנכ”ל הכספים טוען שחבל להציע הנחה, משתנה זה אינו משמעותי מספיק בשביל שהנחה במחיר תצליח לגרום ללקוחות להישאר. מאידך סמנכ“ל שירות הלקוחות טוען שהנחה תעזור מאוד. מנכ”ל החברה ביקש מכם לשפוט - מי מהם צודק? הציעו מודל שיבחן את ההשפעה של הנחה ללקוחות מסוימים כדאי לצמצם את הנטישה. האם כדי לנקוט בטקטיקת מתן הנחות לצורך גידול בפדיון? השתמשו במשתנה monthlycharges כדי לקבוע את שיעור ההנחה (אחוז מתוך משתנה זה, ללקוחות הרלוונטיים).
# some help: 
# first if you try to replicate the code from the diamond's example, you will get an error.
# this is because randomForest expects no character variables, just numeric and factors.
# So how can we turn everything character into factor?
# Intuitively, you would probably do:

# telco_churn <- telco_churn %>%
#    mutate(gender = as.factor(gender),
#           SeniorCitizen = as.factor(SeniorCitizen),
#           ...)

# but this is like doing the same action over and over again.
# wouldn't it be nice to just loop over everything, and if a column is of the wrong type (character)
# just convert it into factor?
# mutate_if() does exactly that.
# it needs a condition called .predicate, and a vector function to operate called .funs. 
# It goes like this:

telco_churn <- telco_churn %>%
   mutate_if(.predicate = funs(typeof(.)=="character"), 
             .funs = funs(as.factor(.)))

# What does it do?
# it's like looping
# for (i in 1:NCOL(telco_churn)){
#    if (typeof(telco_churn[,i]) == "character") {
#       telco_churn[,i] <- as.factor(telco_churn[,i])
#    }
# }

# the syntax of 
# typeof(.) == "character"
# is like "replace . with the vector you are currently checking"
# the syntax of as.factor(.) is like
# "replace . with the vector" that is a character and you need to type case into a factor
# the funs() function makes the expression explicit after replacing the . with the current vector

# now you can continue the exercise...

גישת ה-Boosting

ביערות אקראיים ראינו כיצד חזאי בודד (עץ) משתכפל והופך לשילוב של הרבה חזאים. כאשר מבוצע שילוב של חזאים רבים הדבר מהווה פוטנציאל להפחית שגיאות שעשויות להופיע באופן מקומי בחזאים אשר נתונים פעמים רבות ל“גחמות הסטטיסטיקה” (או שגיאות הנובעות ממינימום מקומי או מהתאמת יתר).

גישה נוספת חוץ מיערות אקראיים היא גישת ה-Boosting. היא יכולה להתאים לגישות שונות (לאו דווקא להכללה של עצים), אך פה נדגים אותה בהקשר העצים.

נניח שהבעיה שלנו היא בעיית רגרסיה (חיזוי ערך של משתנה רציף). ב-Boosting, בכל שלב האלגוריתם יבנה עץ, שהמטרה שלו היא חיזוי השגיאה (לא הערך האמיתי של \(y\) אלא השגיאה הצפוייה בהתבסס על כל העצים שנבנו עד כה).

המודל מתווסף כסכום לכל יתר המודל שחושבו עד כה, עם פרמטר “הקטנה” \(\lambda\).

במילים אחרות, האלגוריתם מרכיב סכום של הרבה עצים קטנים, כשכל פעם הוא נותן דגש על צמצום השגיאות שהתקבלו עד כה.

Pseudo code:
Set r = y, f(x)=0
For b=1, 2,..., B repeat:
   Fit a tree, f_curr, to the data (X,r)
   Update f by adding current learned tree: 
      f <- f + lambda*f_curr
   Update the residuals
      r - r-lambda*f_curr
Output f

במקרה של בעיות סיווג, העדכון של המודלים מתבצע על ידי בניה בכל שלב של מודל סיווג (לדוגמה עץ), תוך כדי מתן משקל גדול יותר לתצפיות אשר הסיווג שלהן שגוי.

ב-R יש שתי חבילות המשמשות לboosting:

תרגיל

השתמשו בחבילת xgboost, בפקודת xgboost כדי לייצר חזאי לנטישת לקוחות. הצגת התוצאות בתרשים ה-ROC. האם ביצוע boosting שיפר את החזאים?

שימו לב, פקודת xgboost דורשת הכנה של מבנה הנתונים למטריצה. הנה קוד שיסייע לכם בהכנת המטריצה. השלימו את ה-XXX עם הפקודות / שמות משתנים המתאימים. שימו לב לשימוש ב-mutate_all, הדומה לפקודה שהשתמשנו בה קודם mutate_if.

# prepare the data
telco_churn_for_boost <- telco_churn %>%
  filter(is_train) %>%
  select(XXX:XXX)

dtrain <- xgb.DMatrix(telco_churn_for_boost %>%
                        mutate_all(funs(as.numeric(XXX))) %>%
                        select(-XXX) %>%
                        as.matrix(), 
                      label = XXX == XXX)

# building the boost predictor
churn_boost <- xgboost::xgboost(data = dtrain, 
                                nrounds = XXX, 
                                params = 
                                  list(objective=XXX,
                                       booster=XXX))

האם יש פרמטרים של העצים הנבנים במהלך אלגוריתם ה-boosting שאתם יכולים לשנות, כך שישפרו את החזאי?