ביחידה זו נלמד על שתי שיטות נוספות המשמשות לחיזוי ורגרסיה: עצים, ויערות אקראיים (trees and random forests).
תזכורת: באחת היחידות הקודמות התאמנו מודל רגרסיה לינארית למחיר יהלומים. כאשר פיצלנו את המודל לשני מודלים העובדים כל אחד על תחום של משתנה קטגורי, איכות החיזוי של המודל השתפרה.
הרעיון הכללי של עצים הוא דומה - לפצל את המרחב, כל פעם לפי משתנה אחר, עד שמגיעים ל“עלים” בהם ניתן החיזוי (כרגרסיה או סיווג).
עצים הם “נחמדים” כי הם מושכים ויזואלית ונוח לפרש אותם, אבל בדרך כלל הם לא נותנים תוצאות טובות, ומאוד רגישים לשינויים קלים (לדוגמה להוספה או החסרה של תצפיות). לכן, הכללות מקובלות לעצים הם כאלו המשכפלות את התהליך פעמים רבות (לדוגמה ל“יער אקראי”, שבו גם נדון ביחידה זו). שכפול התהליך לעצים מרובים הופך את המודל ליותר רובוסטי ויותר מדויק, אבל קצת פחות ברור לפרשנות.
library(tidyverse)
ggplot(diamonds, aes(y = price, x = carat)) +
facet_wrap(~ clarity) +
stat_smooth(method = "lm")
library(rpart)
diamond_price_tree <- rpart(formula = price ~ .,
data = diamonds)
library(rpart.plot)
prp(diamond_price_tree)
diamond_price_tree
## n= 53940
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 53940 858473100000 3932.800
## 2) carat< 0.995 34880 43459420000 1632.641
## 4) y< 5.535 24951 6860691000 1058.546 *
## 5) y>=5.535 9929 7710112000 3075.309 *
## 3) carat>=0.995 19060 292761600000 8142.115
## 6) y< 7.195 12884 60679350000 6137.844
## 12) clarity=I1,SI2,SI1,VS2 9804 20256360000 5397.093 *
## 13) clarity=VS1,VVS2,VVS1,IF 3080 17919640000 8495.739 *
## 7) y>=7.195 6176 72354930000 12323.300
## 14) y< 7.815 3945 33996520000 10899.960
## 28) clarity=I1,SI2 954 3380193000 8375.178 *
## 29) clarity=SI1,VS2,VS1,VVS2,VVS1,IF 2991 22595360000 11705.260
## 58) color=H,I,J 1554 5588830000 10014.970 *
## 59) color=D,E,F,G 1437 7765314000 13533.160 *
## 15) y>=7.815 2231 16233830000 14840.160 *
summary(diamond_price_tree)
## Call:
## rpart(formula = price ~ ., data = diamonds)
## n= 53940
##
## CP nsplit rel error xerror xstd
## 1 0.60834997 0 1.00000000 1.0000821 0.008801006
## 2 0.18605980 1 0.39165003 0.3916954 0.003926316
## 3 0.03365116 2 0.20559023 0.2056335 0.002060823
## 4 0.02621322 3 0.17193908 0.1720624 0.002058132
## 5 0.02577201 4 0.14572586 0.1483153 0.001839371
## 6 0.01005400 5 0.11995385 0.1203474 0.001549010
## 7 0.01000000 7 0.09984584 0.1087609 0.001420495
##
## Variable importance
## carat y x z clarity color
## 25 24 24 23 3 1
##
## Node number 1: 53940 observations, complexity param=0.60835
## mean=3932.8, MSE=1.591533e+07
## left son=2 (34880 obs) right son=3 (19060 obs)
## Primary splits:
## carat < 0.995 to the left, improve=0.60835000, (0 missing)
## y < 6.345 to the left, improve=0.60688030, (0 missing)
## x < 6.335 to the left, improve=0.60323470, (0 missing)
## z < 3.915 to the left, improve=0.59819760, (0 missing)
## color splits as LLLLRRR, improve=0.02222238, (0 missing)
## Surrogate splits:
## x < 6.275 to the left, agree=0.984, adj=0.953, (0 split)
## y < 6.285 to the left, agree=0.981, adj=0.946, (0 split)
## z < 3.895 to the left, agree=0.978, adj=0.937, (0 split)
## clarity splits as RRLLLLLL, agree=0.679, adj=0.091, (0 split)
## color splits as LLLLLRR, agree=0.660, adj=0.039, (0 split)
##
## Node number 2: 34880 observations, complexity param=0.03365116
## mean=1632.641, MSE=1245969
## left son=4 (24951 obs) right son=5 (9929 obs)
## Primary splits:
## y < 5.535 to the left, improve=0.66472620, (0 missing)
## carat < 0.625 to the left, improve=0.66450550, (0 missing)
## x < 5.485 to the left, improve=0.66202330, (0 missing)
## z < 3.375 to the left, improve=0.66100870, (0 missing)
## clarity splits as RRRLLLLL, improve=0.01153592, (0 missing)
## Surrogate splits:
## x < 5.535 to the left, agree=0.992, adj=0.972, (0 split)
## carat < 0.635 to the left, agree=0.991, adj=0.969, (0 split)
## z < 3.405 to the left, agree=0.984, adj=0.942, (0 split)
## clarity splits as RRLLLLLL, agree=0.726, adj=0.037, (0 split)
## table < 62.25 to the left, agree=0.719, adj=0.012, (0 split)
##
## Node number 3: 19060 observations, complexity param=0.1860598
## mean=8142.115, MSE=1.536e+07
## left son=6 (12884 obs) right son=7 (6176 obs)
## Primary splits:
## y < 7.195 to the left, improve=0.54558840, (0 missing)
## x < 7.195 to the left, improve=0.53787920, (0 missing)
## carat < 1.495 to the left, improve=0.53685280, (0 missing)
## z < 4.425 to the left, improve=0.52466320, (0 missing)
## clarity splits as LLLRRRRR, improve=0.05520053, (0 missing)
## Surrogate splits:
## x < 7.185 to the left, agree=0.984, adj=0.950, (0 split)
## carat < 1.445 to the left, agree=0.980, adj=0.938, (0 split)
## z < 4.405 to the left, agree=0.965, adj=0.891, (0 split)
## color splits as LLLLLLR, agree=0.679, adj=0.010, (0 split)
## table < 67.5 to the left, agree=0.676, adj=0.001, (0 split)
##
## Node number 4: 24951 observations
## mean=1058.546, MSE=274966.6
##
## Node number 5: 9929 observations
## mean=3075.309, MSE=776524.5
##
## Node number 6: 12884 observations, complexity param=0.02621322
## mean=6137.844, MSE=4709667
## left son=12 (9804 obs) right son=13 (3080 obs)
## Primary splits:
## clarity splits as LLLLRRRR, improve=0.3708568, (0 missing)
## y < 6.775 to the left, improve=0.1220791, (0 missing)
## carat < 1.175 to the left, improve=0.1068320, (0 missing)
## x < 6.745 to the left, improve=0.1059436, (0 missing)
## color splits as RRRRLLL, improve=0.1042738, (0 missing)
## Surrogate splits:
## table < 49.5 to the right, agree=0.761, adj=0.001, (0 split)
##
## Node number 7: 6176 observations, complexity param=0.02577201
## mean=12323.3, MSE=1.17155e+07
## left son=14 (3945 obs) right son=15 (2231 obs)
## Primary splits:
## y < 7.815 to the left, improve=0.30577850, (0 missing)
## x < 7.845 to the left, improve=0.29884120, (0 missing)
## carat < 1.915 to the left, improve=0.29280650, (0 missing)
## z < 4.805 to the left, improve=0.27696390, (0 missing)
## clarity splits as LRRRRRRR, improve=0.07250237, (0 missing)
## Surrogate splits:
## x < 7.845 to the left, agree=0.983, adj=0.953, (0 split)
## carat < 1.825 to the left, agree=0.971, adj=0.921, (0 split)
## z < 4.825 to the left, agree=0.951, adj=0.866, (0 split)
## clarity splits as RRLLLLLL, agree=0.677, adj=0.106, (0 split)
## depth < 57.05 to the right, agree=0.640, adj=0.003, (0 split)
##
## Node number 12: 9804 observations
## mean=5397.093, MSE=2066132
##
## Node number 13: 3080 observations
## mean=8495.739, MSE=5818064
##
## Node number 14: 3945 observations, complexity param=0.010054
## mean=10899.96, MSE=8617622
## left son=28 (954 obs) right son=29 (2991 obs)
## Primary splits:
## clarity splits as LLRRRRRR, improve=0.23593490, (0 missing)
## color splits as RRRRLLL, improve=0.22434970, (0 missing)
## y < 7.565 to the left, improve=0.03876909, (0 missing)
## carat < 1.615 to the left, improve=0.03231217, (0 missing)
## x < 7.565 to the left, improve=0.02811208, (0 missing)
## Surrogate splits:
## z < 4.855 to the right, agree=0.769, adj=0.045, (0 split)
## depth < 64.45 to the right, agree=0.768, adj=0.040, (0 split)
## carat < 1.825 to the right, agree=0.767, adj=0.035, (0 split)
## cut splits as LRRRR, agree=0.764, adj=0.024, (0 split)
## x < 7.865 to the right, agree=0.760, adj=0.008, (0 split)
##
## Node number 15: 2231 observations
## mean=14840.16, MSE=7276482
##
## Node number 28: 954 observations
## mean=8375.178, MSE=3543180
##
## Node number 29: 2991 observations, complexity param=0.010054
## mean=11705.26, MSE=7554451
## left son=58 (1554 obs) right son=59 (1437 obs)
## Primary splits:
## color splits as RRRRLLL, improve=0.40898740, (0 missing)
## clarity splits as LLLRRRRR, improve=0.09908389, (0 missing)
## carat < 1.615 to the left, improve=0.05752247, (0 missing)
## y < 7.595 to the left, improve=0.05099388, (0 missing)
## z < 4.715 to the left, improve=0.04452641, (0 missing)
## Surrogate splits:
## carat < 1.545 to the right, agree=0.562, adj=0.089, (0 split)
## z < 4.595 to the right, agree=0.556, adj=0.076, (0 split)
## y < 7.325 to the right, agree=0.538, adj=0.038, (0 split)
## x < 7.435 to the right, agree=0.538, adj=0.038, (0 split)
## depth < 60.85 to the right, agree=0.537, adj=0.037, (0 split)
##
## Node number 58: 1554 observations
## mean=10014.97, MSE=3596416
##
## Node number 59: 1437 observations
## mean=13533.16, MSE=5403837
פרמטרים שונים שולטים על עומק העץ. ככל שעץ עמוק יותר, כך אנחנו ניכנס למצבים של Over-fitting. הנה גידול עץ עמוק במיוחד (וכנראה לא מאוד מועיל).
diamond_price_tree_large <- rpart(formula = price ~ .,
data = diamonds,
control = rpart.control(cp = 0.0005, xval = 10))
prp(diamond_price_tree_large)
#summary(diamond_price_tree_large)
הפרמטר cp שאותו שינינו כדי לשלוט על גודל העץ הוא פרמטר מורכבות העץ (complexity parameter). הוא שולט על אלגוריתם הגידול של העץ. כאשר הפרמטר נמוך, האלגוריתם נוטה לבצע יותר פיצולים, וכאשר הפרמטר גבוה, ישנם פחות פיצולים. למעשה הפרמטר מציב רף לפיצול, ורק אם פיצול משפר את החיזוי בערכו של הפרמטר, אז מתבצע הפיצול.
למעשה, אפשר גם לגדל עץ עמוק ואז לגזום אותו prune, כדי לייצר עץ קטן בחזרה. באופן מסויים זה מזכיר את האלגוריתם של step wise selection (במודל של רגרסיה) שאותו הזכרנו ביחידה קודמת. שני האלגוריתמים “הולכים אחורה” ומנסים להוריד משתנים.
איך עובד האלגוריתם של גידול וגיזום עצים?
האלגוריתם מחלק את מרחב התצפיות \(X\) למלבנים, בכל מלבן מתבצע ממוצע של ערכי התצפיות \(y\), וזו התחזית הניתנת לתצפיות חדשות השייכות לאותו המלבן.
כלומר, האלגוריתם מנסה למזער את הגודל הבא:
\[\sum_{j=1}^J\sum_{i\in R_j}\left(y_i-\hat{y}_{R_j}\right)^2\]
כאשר \(J\) הוא מספר המלבנים אליהם מחולק מרחב ה-\(X\).
בכל רגע נתון באלגוריתם, נבחן הפיצול הטוב ביותר כרגע, כלומר, האלגוריתם מחפש את המשתנה ה-\(X_j\) והערך הקריטי \(s\), כך שיביא למינימום את הגודל:
\[\sum_{i: x_i\in R_1(j,s)}\left(y_i-\hat{y}_{R_1}\right)^2 + \sum_{i: x_i\in R_2(j,s)}\left(y_i-\hat{y}_{R_2}\right)^2\]
כאשר:
\[R_1(j,s) = \left\{X|X_j<s\right\} \text{ and } R_2(j,s) = \left\{X|X_j\geq s\right\}\]
במילים אחרות, מדובר באלגוריתם “חמדן” (הוא תמיד מחפש את הדבר הטוב ביותר כרגע, ולעיתים זה יכול להוביל לתוצאות לא טובות).
כדי לגזום עץ (להקטין את מידת הסיבוכיות שלו) ניתן להשתמש בפקודה prune.
diamond_price_pruned <- prune(diamond_price_tree_large, cp = 0.05)
prp(diamond_price_pruned)
כדי לבחון מה פרמטר ה-cp הרצוי, מומלץ להשתמש ב-cross validation.
מה עושה cross validation?
האלגוריתם של rpart למעשה עושה את כל זה עבורנו.
# here is the cp table
diamond_price_tree_large$cptable
## CP nsplit rel error xerror xstd
## 1 0.6083499681 0 1.00000000 1.00006475 0.0088010440
## 2 0.1860597998 1 0.39165003 0.39173358 0.0039266775
## 3 0.0336511552 2 0.20559023 0.20564109 0.0020608645
## 4 0.0262132207 3 0.17193908 0.17206300 0.0020582035
## 5 0.0257720103 4 0.14572586 0.14453736 0.0017669383
## 6 0.0100540028 5 0.11995385 0.12076185 0.0015626320
## 7 0.0059658108 7 0.09984584 0.09960697 0.0013044269
## 8 0.0059115945 8 0.09388003 0.09427032 0.0012511756
## 9 0.0050256230 9 0.08796843 0.08891238 0.0012003879
## 10 0.0039210262 10 0.08294281 0.08398700 0.0012026078
## 11 0.0034501246 11 0.07902179 0.08159645 0.0011398941
## 12 0.0032341022 12 0.07557166 0.07568072 0.0010737729
## 13 0.0031204544 13 0.07233756 0.07353013 0.0010593893
## 14 0.0028156712 14 0.06921710 0.07104537 0.0010334381
## 15 0.0020441493 15 0.06640143 0.06711342 0.0009792445
## 16 0.0017401225 16 0.06435728 0.06587602 0.0009567098
## 17 0.0016567915 17 0.06261716 0.06384242 0.0009320158
## 18 0.0011974445 19 0.05930358 0.06028437 0.0008966664
## 19 0.0011041126 20 0.05810613 0.05842785 0.0008714159
## 20 0.0010277728 21 0.05700202 0.05785724 0.0008638224
## 21 0.0009978567 22 0.05597425 0.05718099 0.0008425371
## 22 0.0009402965 23 0.05497639 0.05508649 0.0008258311
## 23 0.0008722432 24 0.05403610 0.05428277 0.0008136440
## 24 0.0008643220 25 0.05316385 0.05412565 0.0008082537
## 25 0.0008225882 26 0.05229953 0.05368813 0.0008048566
## 26 0.0007640001 28 0.05065435 0.05195502 0.0007717015
## 27 0.0007258915 29 0.04989035 0.05106837 0.0007595777
## 28 0.0006199210 30 0.04916446 0.05038140 0.0007519656
## 29 0.0006127750 31 0.04854454 0.04934351 0.0007385028
## 30 0.0005904991 32 0.04793177 0.04903457 0.0007344071
## 31 0.0005707747 33 0.04734127 0.04845750 0.0007182557
## 32 0.0005611596 34 0.04677049 0.04775787 0.0007063952
## 33 0.0005606257 35 0.04620933 0.04734673 0.0007033127
## 34 0.0005498450 36 0.04564871 0.04715392 0.0007023142
## 35 0.0005170202 37 0.04509886 0.04650942 0.0006829172
## 36 0.0005000000 38 0.04458184 0.04586644 0.0006783903
# the shortest way - use a predefined function to plot the xval cp errors
rpart::plotcp(diamond_price_tree_large)
במקרה זה, היות שיש לנו מדגם מאוד גדול של יהלומים, השגיאה אכן קטנה כאשר הפרמטר קטן, ואנחנו עוד לא בתחום של overfitting. במקרים אחרים ייתכן שנראה גרף שאינו מונוטוני, כלומר כאשר נקטין את cp, בשלב מסוים נקבל מצב שבו השגיאה גדלה.
כמו כן, ניתן לראות שהתפוקה השולית של הקטנת ה-cp פוחתת.
עד כה, דנו בעץ רגרסיה - עץ המספק חיזוי לערכים רציפים. ישנם עצי סיווג הרלוונטיים במקרים בהם נדרש לסווג תצפיות לערכים בדידים (classification), בדומה לאלגוריתמים אחרים שדנו בהם ביחידות קודמות (knn, רגרסיה לוגיסטית, lda, ו-qda).
במקרה של עצי סיווג, לא משתמשים בשגיאת RSS אלא במדד אחר הנקרא Gini index.
\[G = \sum_{k=1}^K\hat{p}_{mk}(1-\hat{p}_{mk})\]
כאשר \(\hat{p}_{mk}\) הוא הפרופורציה של תצפיות במרחב ה-\(m\), עם סיווג \(k\). מדד זה נמוך ככל שערכי \(\hat{p}_{mk}\) הם יותר קיצוניים (קרובים ל-0 או ל-1), קרי העלים של העץ “טהורים”.
ggplot(tibble(p = seq(0, 1, 0.01)), aes(x = p, y = p*(1-p))) +
geom_line() +
ylab("G = p*(1-p)") +
ggtitle("Illustration: Gini index will be minimized when p=1 or p=0")
בחלק מהאלגוריתמים נעשה שימוש במדד מקביל למדד Gini: מדד האנטרופיה.
\[D = -\sum_{k=1}^K{\hat{p}_{mk}\log\hat{p}_{mk}}\]
בתרגיל הבא נשתמש בעצי החלטה כדי לחזות את הסבירות לנטישה של לקוח.
אבל
לכן, נבנו מספר אלגוריתמים המבוססים על עצים אך הם יותר רובוסטים ולרוב בעלי ביצועים טובים יותר.
האלגוריתם של יערות אקראיים בונה מקבץ של של עצים רבים, כאשר בכל עץ בכל פיצול, הוא מגביל את המשתנים לפיהם הוא יכול לפצל ל-\(m\) משתנים בלבד (מתוך \(p\) אפשריים). בדרך כלל \(m\approx \sqrt{p}\)).
כמו כן, האלגוריתם מגריל תצפיות (במקום להשתמש בכל התצפיות הוא משתמש במדגם שלהן), לצורך בנייתו של עץ.
אלגוריתם זה יכול לצמצם את ההשפעות של קורלציה בין משתנים, וכמו כן, הוא נותן הזדמנות למשתנים מסבירים שונים לבוא לידי ביטוי, אפילו אם הם לא בעלי העוצמה החזקה ביותר.
לבסוף התוצר המתקבל הוא ממוצע החיזויים על פני כלל העצים.
ערך מוסף בעצים הוא שניתן לחשב את ההפחתה הממוצעת במדד Gini של כל אחד מהמשתנים המסבירים, וזה מאפשר לדרג אותם לפי סדר חשיבות. בבעיית רגרסיה, החשיבות מסודרת לפי מידת “הטהורות” (במונחי RSS) שהוספת משתנה מסוים תרמה לדיוק, בממוצע.
library(randomForest)
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
# note the use of maxnodes, otherwise the trees are grown to maximal size
# also limiting the number of trees to 150 - the default is 500...
diamond_price_forest <- randomForest(
formula = price ~ .,
data = diamonds,
maxnodes = 15,
ntree = 150)
# plot the importance plot
varImpPlot(diamond_price_forest)
# show an example of the first tree
getTree(diamond_price_forest, k = 1)
## left daughter right daughter split var split point status prediction
## 1 2 3 7 6.335 -3 3927.2732
## 2 4 5 1 0.625 -3 1697.8672
## 3 6 7 5 59.650 -3 8263.3924
## 4 8 9 7 4.965 -3 1048.4733
## 5 10 11 8 6.085 -3 3209.1168
## 6 12 13 9 4.275 -3 7528.9389
## 7 14 15 1 1.475 -3 8348.2489
## 8 16 17 1 0.375 -3 784.5946
## 9 18 19 3 4.500 -3 1675.8869
## 10 20 21 4 3.500 -3 2757.5049
## 11 22 23 7 6.265 -3 4291.0948
## 12 24 25 4 3.500 -3 5427.3591
## 13 26 27 9 4.615 -3 12049.7791
## 14 28 29 8 6.775 -3 6274.9434
## 15 0 0 0 0.000 -1 12242.9665
## 16 0 0 0 0.000 -1 702.0850
## 17 0 0 0 0.000 -1 971.1848
## 18 0 0 0 0.000 -1 1735.8274
## 19 0 0 0 0.000 -1 1382.9298
## 20 0 0 0 0.000 -1 2494.9158
## 21 0 0 0 0.000 -1 3023.1710
## 22 0 0 0 0.000 -1 4086.6390
## 23 0 0 0 0.000 -1 4716.7947
## 24 0 0 0 0.000 -1 4475.0434
## 25 0 0 0 0.000 -1 7002.1926
## 26 0 0 0 0.000 -1 10158.5433
## 27 0 0 0 0.000 -1 14422.6779
## 28 0 0 0 0.000 -1 5777.2421
## 29 0 0 0 0.000 -1 7495.8592
# some help:
# first if you try to replicate the code from the diamond's example, you will get an error.
# this is because randomForest expects no character variables, just numeric and factors.
# So how can we turn everything character into factor?
# Intuitively, you would probably do:
# telco_churn <- telco_churn %>%
# mutate(gender = as.factor(gender),
# SeniorCitizen = as.factor(SeniorCitizen),
# ...)
# but this is like doing the same action over and over again.
# wouldn't it be nice to just loop over everything, and if a column is of the wrong type (character)
# just convert it into factor?
# mutate_if() does exactly that.
# it needs a condition called .predicate, and a vector function to operate called .funs.
# It goes like this:
telco_churn <- telco_churn %>%
mutate_if(.predicate = funs(typeof(.)=="character"),
.funs = funs(as.factor(.)))
# What does it do?
# it's like looping
# for (i in 1:NCOL(telco_churn)){
# if (typeof(telco_churn[,i]) == "character") {
# telco_churn[,i] <- as.factor(telco_churn[,i])
# }
# }
# the syntax of
# typeof(.) == "character"
# is like "replace . with the vector you are currently checking"
# the syntax of as.factor(.) is like
# "replace . with the vector" that is a character and you need to type case into a factor
# the funs() function makes the expression explicit after replacing the . with the current vector
# now you can continue the exercise...
ביערות אקראיים ראינו כיצד חזאי בודד (עץ) משתכפל והופך לשילוב של הרבה חזאים. כאשר מבוצע שילוב של חזאים רבים הדבר מהווה פוטנציאל להפחית שגיאות שעשויות להופיע באופן מקומי בחזאים אשר נתונים פעמים רבות ל“גחמות הסטטיסטיקה” (או שגיאות הנובעות ממינימום מקומי או מהתאמת יתר).
גישה נוספת חוץ מיערות אקראיים היא גישת ה-Boosting. היא יכולה להתאים לגישות שונות (לאו דווקא להכללה של עצים), אך פה נדגים אותה בהקשר העצים.
נניח שהבעיה שלנו היא בעיית רגרסיה (חיזוי ערך של משתנה רציף). ב-Boosting, בכל שלב האלגוריתם יבנה עץ, שהמטרה שלו היא חיזוי השגיאה (לא הערך האמיתי של \(y\) אלא השגיאה הצפוייה בהתבסס על כל העצים שנבנו עד כה).
המודל מתווסף כסכום לכל יתר המודל שחושבו עד כה, עם פרמטר “הקטנה” \(\lambda\).
במילים אחרות, האלגוריתם מרכיב סכום של הרבה עצים קטנים, כשכל פעם הוא נותן דגש על צמצום השגיאות שהתקבלו עד כה.
Pseudo code:
Set r = y, f(x)=0
For b=1, 2,..., B repeat:
Fit a tree, f_curr, to the data (X,r)
Update f by adding current learned tree:
f <- f + lambda*f_curr
Update the residuals
r - r-lambda*f_curr
Output f
במקרה של בעיות סיווג, העדכון של המודלים מתבצע על ידי בניה בכל שלב של מודל סיווג (לדוגמה עץ), תוך כדי מתן משקל גדול יותר לתצפיות אשר הסיווג שלהן שגוי.
ב-R יש שתי חבילות המשמשות לboosting:
השתמשו בחבילת xgboost, בפקודת xgboost כדי לייצר חזאי לנטישת לקוחות. הצגת התוצאות בתרשים ה-ROC. האם ביצוע boosting שיפר את החזאים?
שימו לב, פקודת xgboost דורשת הכנה של מבנה הנתונים למטריצה. הנה קוד שיסייע לכם בהכנת המטריצה. השלימו את ה-XXX עם הפקודות / שמות משתנים המתאימים. שימו לב לשימוש ב-mutate_all, הדומה לפקודה שהשתמשנו בה קודם mutate_if.
# prepare the data
telco_churn_for_boost <- telco_churn %>%
filter(is_train) %>%
select(XXX:XXX)
dtrain <- xgb.DMatrix(telco_churn_for_boost %>%
mutate_all(funs(as.numeric(XXX))) %>%
select(-XXX) %>%
as.matrix(),
label = XXX == XXX)
# building the boost predictor
churn_boost <- xgboost::xgboost(data = dtrain,
nrounds = XXX,
params =
list(objective=XXX,
booster=XXX))
האם יש פרמטרים של העצים הנבנים במהלך אלגוריתם ה-boosting שאתם יכולים לשנות, כך שישפרו את החזאי?