In the previous units we saw a variety of algorithms. Along the way we mentioned a concept called over-fitting, and in every computation we were careful to split the data into a train set and a test set.
In this section we demonstrate the importance of the train/test split, and the caution required to avoid overfitting.
Overfitting occurs when the number of explanatory variables \(p\) is very large compared to the sample size \(n\).
In other words, as the number of variables grows, algorithms manage to find a model that fits the observed data well, seemingly describing a statistical relationship; in reality, the algorithm is describing a random, coincidental relationship, using the many degrees of freedom available in its parameters.
A demonstration:
library(tidyverse)
# simulate 100 observations of 95 explanatory variables, all uniform on [0, 1]
# (no seed is set, so a fresh run will produce different numbers than those shown below)
xvars <- data.frame(matrix(runif(100 * 95), ncol = 95))
overfitting <- tibble(y = runif(100)) %>%
  bind_cols(xvars)
glimpse(overfitting)
## Observations: 100
## Variables: 96
## $ y <dbl> 0.7089986, 0.2903008, 0.2325977, 0.5330833, 0.1241898, 0.7...
## $ X1 <dbl> 0.619575535, 0.899619185, 0.526324802, 0.797330922, 0.8003...
## $ X2 <dbl> 0.3356246, 0.1855314, 0.2111435, 0.4857264, 0.2592875, 0.4...
## $ X3 <dbl> 0.05733814, 0.33154772, 0.91916158, 0.19071112, 0.86667501...
## $ X4 <dbl> 0.07075963, 0.17849404, 0.56142045, 0.73903736, 0.90931825...
## $ X5 <dbl> 0.90530667, 0.69761504, 0.76214690, 0.13297892, 0.72014595...
## $ X6 <dbl> 0.926684984, 0.917220576, 0.899580207, 0.372546137, 0.4372...
## $ X7 <dbl> 0.62433765, 0.20075305, 0.97032955, 0.24861709, 0.23294241...
## $ X8 <dbl> 0.1845216, 0.9790208, 0.1168240, 0.5551002, 0.4034280, 0.6...
## $ X9 <dbl> 0.9975084, 0.6891551, 0.7544424, 0.1501899, 0.6474128, 0.6...
## $ X10 <dbl> 0.167975825, 0.723071146, 0.543455848, 0.032874736, 0.6633...
## $ X11 <dbl> 0.29473194, 0.29811168, 0.66438858, 0.58404737, 0.35506948...
## $ X12 <dbl> 0.74933204, 0.39332822, 0.17481607, 0.72169066, 0.18004782...
## $ X13 <dbl> 0.75818476, 0.45023901, 0.82799755, 0.50486737, 0.28818241...
## $ X14 <dbl> 0.02500514, 0.53166926, 0.04214480, 0.87346605, 0.74550727...
## $ X15 <dbl> 0.67567010, 0.34008779, 0.11339664, 0.40079429, 0.51798371...
## $ X16 <dbl> 0.1442445, 0.2809641, 0.5323024, 0.6707393, 0.8546232, 0.6...
## $ X17 <dbl> 0.47954964, 0.57758237, 0.03613206, 0.19995500, 0.88160174...
## $ X18 <dbl> 0.77517537, 0.55127005, 0.98536289, 0.75764646, 0.28343731...
## $ X19 <dbl> 0.788750990, 0.809515184, 0.619781677, 0.968460454, 0.7808...
## $ X20 <dbl> 0.28163121, 0.40658971, 0.35327382, 0.94752835, 0.74938901...
## $ X21 <dbl> 0.1723946, 0.6720742, 0.7847495, 0.9581626, 0.7083294, 0.9...
## $ X22 <dbl> 0.94247637, 0.94415301, 0.38278227, 0.31410514, 0.14573959...
## $ X23 <dbl> 0.25266289, 0.50520089, 0.25369704, 0.37217720, 0.22047173...
## $ X24 <dbl> 0.007308053, 0.466543796, 0.424070933, 0.694347189, 0.8798...
## $ X25 <dbl> 0.970274905, 0.583504309, 0.800831113, 0.167429144, 0.4228...
## $ X26 <dbl> 0.53167232, 0.95940371, 0.31224646, 0.99837161, 0.61569190...
## $ X27 <dbl> 0.57421502, 0.46311628, 0.27905687, 0.25135528, 0.35998749...
## $ X28 <dbl> 0.76704330, 0.21635119, 0.31283717, 0.65894494, 0.86053187...
## $ X29 <dbl> 0.12292826, 0.31934540, 0.91608865, 0.01584728, 0.97940715...
## $ X30 <dbl> 0.22891933, 0.56078554, 0.70113860, 0.05932347, 0.49508451...
## $ X31 <dbl> 0.85305427, 0.06841765, 0.23349174, 0.48582019, 0.88397211...
## $ X32 <dbl> 0.03113444, 0.30614071, 0.89450422, 0.22408335, 0.35976227...
## $ X33 <dbl> 0.02164700, 0.07401357, 0.83564407, 0.98764889, 0.09088551...
## $ X34 <dbl> 0.2696412, 0.7108775, 0.7341540, 0.2697968, 0.3624979, 0.6...
## $ X35 <dbl> 0.320782271, 0.588473998, 0.486805190, 0.925924307, 0.0381...
## $ X36 <dbl> 0.98755162, 0.98473189, 0.26018276, 0.15832839, 0.14243235...
## $ X37 <dbl> 0.2862225, 0.5370663, 0.4572374, 0.9198104, 0.2152340, 0.0...
## $ X38 <dbl> 0.052688550, 0.436476058, 0.061697636, 0.985184224, 0.9432...
## $ X39 <dbl> 0.43836777, 0.24706211, 0.50049517, 0.99940944, 0.99868643...
## $ X40 <dbl> 0.521275533, 0.395116090, 0.423826498, 0.554926199, 0.2291...
## $ X41 <dbl> 0.97311151, 0.47235503, 0.67833843, 0.19339717, 0.65781704...
## $ X42 <dbl> 0.248431407, 0.005146637, 0.969308545, 0.276617155, 0.1857...
## $ X43 <dbl> 0.5080576, 0.4427110, 0.6679994, 0.1347259, 0.6523113, 0.8...
## $ X44 <dbl> 0.40061522, 0.01997527, 0.79103114, 0.29346586, 0.53881140...
## $ X45 <dbl> 0.96380904, 0.23399784, 0.95964714, 0.04463061, 0.31930508...
## $ X46 <dbl> 0.28601891, 0.81181235, 0.44594990, 0.63458057, 0.82823654...
## $ X47 <dbl> 0.02077297, 0.07269980, 0.23653647, 0.12454724, 0.02816920...
## $ X48 <dbl> 0.90443248, 0.06945828, 0.18666305, 0.91806952, 0.15030980...
## $ X49 <dbl> 0.77868224, 0.20669518, 0.71549280, 0.02068446, 0.11806395...
## $ X50 <dbl> 0.95519671, 0.93217992, 0.43177897, 0.39629248, 0.27116443...
## $ X51 <dbl> 0.862101254, 0.770861074, 0.318981064, 0.168083573, 0.7692...
## $ X52 <dbl> 0.13480913, 0.34445786, 0.29208633, 0.91072239, 0.09709552...
## $ X53 <dbl> 0.82018167, 0.14150137, 0.68362884, 0.78768875, 0.91225545...
## $ X54 <dbl> 0.73341617, 0.79551783, 0.49868291, 0.49469305, 0.91240890...
## $ X55 <dbl> 0.419185707, 0.309445252, 0.173036704, 0.064510627, 0.0010...
## $ X56 <dbl> 0.431528284, 0.368719596, 0.492042887, 0.261681683, 0.6668...
## $ X57 <dbl> 0.18112053, 0.99507521, 0.08505161, 0.54702360, 0.46818076...
## $ X58 <dbl> 0.04648844, 0.59922984, 0.45472771, 0.12751548, 0.38072173...
## $ X59 <dbl> 0.40068713, 0.03837728, 0.89181267, 0.48057507, 0.93588263...
## $ X60 <dbl> 0.03430447, 0.36435655, 0.88154018, 0.76483513, 0.45207317...
## $ X61 <dbl> 0.73432849, 0.80960489, 0.78099743, 0.72469981, 0.48719337...
## $ X62 <dbl> 0.10556438, 0.02151259, 0.68868997, 0.60759973, 0.93068182...
## $ X63 <dbl> 0.049169683, 0.002650928, 0.838758925, 0.928084582, 0.3498...
## $ X64 <dbl> 0.41719576, 0.04931209, 0.79958266, 0.74351909, 0.95615034...
## $ X65 <dbl> 0.07419834, 0.21196549, 0.79339840, 0.44937311, 0.97695927...
## $ X66 <dbl> 0.95155465, 0.92778168, 0.48042207, 0.03327807, 0.67837089...
## $ X67 <dbl> 0.171465621, 0.299915533, 0.004499343, 0.434086087, 0.0779...
## $ X68 <dbl> 0.01796292, 0.63638499, 0.45528049, 0.08723639, 0.20832351...
## $ X69 <dbl> 0.4405857, 0.8731447, 0.1244418, 0.1751426, 0.6632684, 0.9...
## $ X70 <dbl> 0.95384931, 0.14038067, 0.13566131, 0.38467613, 0.03042929...
## $ X71 <dbl> 0.85806194, 0.56460911, 0.83737858, 0.30946500, 0.64349164...
## $ X72 <dbl> 0.53777852, 0.21344026, 0.07911339, 0.29229307, 0.16083507...
## $ X73 <dbl> 0.89964483, 0.06764685, 0.67411464, 0.80158839, 0.73114918...
## $ X74 <dbl> 0.59048599, 0.81797222, 0.87230988, 0.13529926, 0.05424756...
## $ X75 <dbl> 0.686163759, 0.335998049, 0.756478863, 0.001556719, 0.2076...
## $ X76 <dbl> 0.76971048, 0.78360143, 0.25307252, 0.18701104, 0.33148095...
## $ X77 <dbl> 0.7759770, 0.8177609, 0.7698416, 0.8594684, 0.4142036, 0.6...
## $ X78 <dbl> 0.32011005, 0.35041643, 0.65890719, 0.35551257, 0.29391438...
## $ X79 <dbl> 0.84099329, 0.24179980, 0.23983253, 0.96846815, 0.47056007...
## $ X80 <dbl> 0.861469706, 0.729557837, 0.757376592, 0.657397750, 0.9479...
## $ X81 <dbl> 0.52205236, 0.93355915, 0.65700572, 0.04703858, 0.02417336...
## $ X82 <dbl> 0.750052834, 0.076307797, 0.952641464, 0.657753644, 0.3252...
## $ X83 <dbl> 0.52380760, 0.59691990, 0.29589242, 0.80793851, 0.71335933...
## $ X84 <dbl> 0.95124598, 0.47327211, 0.28512664, 0.11815087, 0.41114094...
## $ X85 <dbl> 0.33718894, 0.57207048, 0.65111439, 0.79303920, 0.36162134...
## $ X86 <dbl> 0.44982642, 0.63153665, 0.52759030, 0.08598975, 0.59544034...
## $ X87 <dbl> 0.09052119, 0.29012751, 0.89028523, 0.04323099, 0.22819343...
## $ X88 <dbl> 0.89368520, 0.48912468, 0.74618919, 0.68833499, 0.35267304...
## $ X89 <dbl> 0.87702854, 0.76879001, 0.77406906, 0.01669727, 0.92174791...
## $ X90 <dbl> 0.91865363, 0.10844596, 0.65514847, 0.55367396, 0.18522627...
## $ X91 <dbl> 0.788593944, 0.670447605, 0.973700038, 0.740475087, 0.4282...
## $ X92 <dbl> 0.6621391, 0.8742585, 0.8830044, 0.7829854, 0.6540070, 0.9...
## $ X93 <dbl> 0.80757343, 0.80161619, 0.91951942, 0.11957843, 0.06100728...
## $ X94 <dbl> 0.44396502, 0.67364542, 0.71246677, 0.78108032, 0.66719770...
## $ X95 <dbl> 0.3311924334, 0.2290834789, 0.8129605544, 0.7725442853, 0....
ggplot(overfitting, aes(y)) + geom_histogram()
# these are just uniformly distributed numbers; there should be no relationship
# between the variables.
# here is a model with just a few X's, and no overfitting. The model is insignificant:
# the only significant coefficient is the intercept (which roughly equals the mean of y)
lm_no_overfit <- lm(data = overfitting,
                    formula = y ~ X1 + X2 + X3)
summary(lm_no_overfit)
##
## Call:
## lm(formula = y ~ X1 + X2 + X3, data = overfitting)
##
## Residuals:
## Min 1Q Median 3Q Max
## -0.47712 -0.23430 0.00585 0.23836 0.49926
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 0.47886 0.08421 5.687 1.4e-07 ***
## X1 0.07653 0.09755 0.784 0.435
## X2 0.06358 0.09539 0.667 0.507
## X3 -0.06808 0.09903 -0.687 0.493
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.2767 on 96 degrees of freedom
## Multiple R-squared: 0.01836, Adjusted R-squared: -0.01232
## F-statistic: 0.5984 on 3 and 96 DF, p-value: 0.6176
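# as promised: R-squared is close to zero and the overall F-test is far from
# significant (p-value = 0.6176)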
# now see what happens when we add all 95 features.
# in particular, look at the R^2: it is almost 1!
lm_overfit <- lm(data = overfitting,
                 formula = y ~ .)
summary(lm_overfit)
##
## Call:
## lm(formula = y ~ ., data = overfitting)
##
## Residuals:
## 1 2 3 4 5 6
## -0.0133397 -0.0046578 -0.0891290 -0.0736334 -0.0515931 0.0116765
## 7 8 9 10 11 12
## -0.0021997 -0.0109935 -0.0572086 -0.0256170 0.0656870 -0.0117807
## 13 14 15 16 17 18
## 0.0847625 0.0206311 -0.0985456 0.0779288 0.0220502 0.0919690
## 19 20 21 22 23 24
## 0.0295509 -0.0774051 -0.0993854 0.0583273 -0.0343006 -0.0255282
## 25 26 27 28 29 30
## -0.0126529 0.0975615 -0.1881224 -0.0695484 -0.0789757 -0.0772258
## 31 32 33 34 35 36
## -0.0350661 0.0174622 0.0746242 0.0381204 0.0198678 -0.0279989
## 37 38 39 40 41 42
## -0.1698548 0.1401285 0.0235570 -0.1442771 0.0722703 0.0188309
## 43 44 45 46 47 48
## 0.0447015 -0.0105864 -0.0141958 -0.1007526 -0.0807882 -0.1121918
## 49 50 51 52 53 54
## -0.1602813 -0.0171016 0.0293796 0.1189691 -0.1034747 0.0640453
## 55 56 57 58 59 60
## -0.0340786 0.1399913 -0.0447618 0.0591502 -0.0433596 0.0334739
## 61 62 63 64 65 66
## 0.0304887 -0.1232479 0.0346934 0.0920425 0.0138008 -0.0424423
## 67 68 69 70 71 72
## -0.0180623 0.0265269 -0.1018066 -0.0760780 0.0387847 -0.0365744
## 73 74 75 76 77 78
## 0.0007957 0.0387094 0.1461009 -0.0757832 -0.0115106 -0.0406971
## 79 80 81 82 83 84
## 0.1093784 0.1116701 0.1337111 0.0969569 0.0650528 0.0212230
## 85 86 87 88 89 90
## 0.0705906 -0.0782842 -0.1077949 -0.0018812 -0.0248346 0.0814677
## 91 92 93 94 95 96
## 0.0554963 -0.0858042 -0.0352045 -0.0328825 0.0498148 0.0546531
## 97 98 99 100
## 0.0397033 0.0555541 0.1554900 0.1160781
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -4.122810 3.444076 -1.197 0.297
## X1 0.351015 0.380474 0.923 0.408
## X2 -0.426239 0.640022 -0.666 0.542
## X3 -0.248992 0.717441 -0.347 0.746
## X4 0.430994 0.715993 0.602 0.580
## X5 0.174900 1.164430 0.150 0.888
## X6 0.004448 0.633271 0.007 0.995
## X7 0.168347 0.667188 0.252 0.813
## X8 0.362422 0.425257 0.852 0.442
## X9 -0.930122 0.732536 -1.270 0.273
## X10 -0.243125 0.664293 -0.366 0.733
## X11 0.327994 0.488677 0.671 0.539
## X12 -0.298612 0.409995 -0.728 0.507
## X13 -0.065611 1.022115 -0.064 0.952
## X14 0.084612 0.739953 0.114 0.914
## X15 -0.326629 0.663982 -0.492 0.649
## X16 0.135523 0.433613 0.313 0.770
## X17 0.173984 0.526083 0.331 0.757
## X18 -0.604196 0.797296 -0.758 0.491
## X19 0.383549 0.455734 0.842 0.447
## X20 -0.570253 0.596491 -0.956 0.393
## X21 -0.002498 0.590407 -0.004 0.997
## X22 0.189672 0.490991 0.386 0.719
## X23 0.659353 0.741043 0.890 0.424
## X24 0.219728 0.425720 0.516 0.633
## X25 -0.631699 0.634657 -0.995 0.376
## X26 0.226392 0.588597 0.385 0.720
## X27 0.233029 0.473520 0.492 0.648
## X28 0.087991 0.831848 0.106 0.921
## X29 -0.325509 0.777508 -0.419 0.697
## X30 -0.309381 0.562588 -0.550 0.612
## X31 0.200576 0.528077 0.380 0.723
## X32 0.090398 0.849877 0.106 0.920
## X33 0.689791 0.580762 1.188 0.301
## X34 0.634693 1.140410 0.557 0.608
## X35 0.885712 0.686887 1.289 0.267
## X36 0.443227 0.697490 0.635 0.560
## X37 -0.067152 0.338836 -0.198 0.853
## X38 0.273851 0.512294 0.535 0.621
## X39 0.472776 0.542641 0.871 0.433
## X40 0.152980 0.735907 0.208 0.845
## X41 0.196330 0.571776 0.343 0.749
## X42 -0.059716 0.666692 -0.090 0.933
## X43 -0.590360 0.785326 -0.752 0.494
## X44 0.517218 0.900917 0.574 0.597
## X45 0.105339 0.594327 0.177 0.868
## X46 -0.018125 0.407940 -0.044 0.967
## X47 0.585905 0.661392 0.886 0.426
## X48 -0.175085 0.454045 -0.386 0.719
## X49 0.231498 0.755916 0.306 0.775
## X50 0.122044 0.481906 0.253 0.813
## X51 0.298945 0.700799 0.427 0.692
## X52 -0.195946 0.516324 -0.380 0.724
## X53 0.440193 0.688366 0.639 0.557
## X54 -0.539260 0.374728 -1.439 0.224
## X55 0.333631 0.921059 0.362 0.736
## X56 0.783178 0.664781 1.178 0.304
## X57 0.648622 0.978015 0.663 0.543
## X58 -0.068926 0.568410 -0.121 0.909
## X59 -0.167267 0.457221 -0.366 0.733
## X60 -0.497113 0.375513 -1.324 0.256
## X61 -0.687913 0.586053 -1.174 0.306
## X62 0.132609 0.646357 0.205 0.847
## X63 -0.166107 1.016299 -0.163 0.878
## X64 0.637944 0.637305 1.001 0.373
## X65 0.586959 0.571207 1.028 0.362
## X66 0.310789 0.400372 0.776 0.481
## X67 -0.706332 0.554787 -1.273 0.272
## X68 0.101024 0.433206 0.233 0.827
## X69 0.935187 0.727906 1.285 0.268
## X70 0.373803 0.798914 0.468 0.664
## X71 0.156401 0.532277 0.294 0.784
## X72 -0.436072 0.617317 -0.706 0.519
## X73 -0.325023 0.476374 -0.682 0.533
## X74 -0.567945 1.469147 -0.387 0.719
## X75 0.690781 0.575654 1.200 0.296
## X76 0.070859 0.504427 0.140 0.895
## X77 0.247754 0.876815 0.283 0.792
## X78 0.465667 0.581472 0.801 0.468
## X79 -0.280437 0.447860 -0.626 0.565
## X80 0.066715 0.815171 0.082 0.939
## X81 -0.198953 0.693930 -0.287 0.789
## X82 -0.177796 0.572976 -0.310 0.772
## X83 0.019106 0.667048 0.029 0.979
## X84 0.026787 0.516964 0.052 0.961
## X85 -0.363488 0.398647 -0.912 0.413
## X86 -0.131453 0.625147 -0.210 0.844
## X87 -0.206379 0.580365 -0.356 0.740
## X88 1.685545 0.868952 1.940 0.124
## X89 0.194877 0.554727 0.351 0.743
## X90 1.072297 0.652313 1.644 0.176
## X91 0.015586 0.734612 0.021 0.984
## X92 0.187272 0.688340 0.272 0.799
## X93 0.724166 0.959864 0.754 0.493
## X94 -0.135233 0.461560 -0.293 0.784
## X95 0.071528 0.664853 0.108 0.920
##
## Residual standard error: 0.3769 on 4 degrees of freedom
## Multiple R-squared: 0.9241, Adjusted R-squared: -0.8787
## F-statistic: 0.5126 on 95 and 4 DF, p-value: 0.8917
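# note: only 4 residual degrees of freedom remain (100 observations minus 96
# estimated coefficients), and the adjusted R-squared is actually negative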
# now, see the errors of each model
overfitting <- overfitting %>%
  mutate(res_no_overfit = y - predict(lm_no_overfit, newdata = overfitting),
         res_overfit = y - predict(lm_overfit, newdata = overfitting))
overfitting %>%
  summarize(mean(abs(res_no_overfit)),
            mean(abs(res_overfit)))
## # A tibble: 1 x 2
## `mean(abs(res_no_overfit))` `mean(abs(res_overfit))`
## <dbl> <dbl>
## 1 0.234 0.0619
# a roughly 74% reduction in mean absolute residual error!
So far we have worked without a train/test split, and on the face of it, the model we fit with many variables looks really good. As you have probably guessed, this is a bluff...
Now we repeat the exercise, only this time we evaluate ourselves on a test set.
overfitting <- overfitting %>%
  mutate(is_train = runif(nrow(overfitting)) < 0.8)  # ~80% train, ~20% test
lm_overfit_train <- lm(data = overfitting %>% filter(is_train),
                       formula = y ~ .)
overfitting <- overfitting %>%
  mutate(res_overfit_train = y - predict(lm_overfit_train, newdata = overfitting))
## Warning in predict.lm(lm_overfit_train, newdata = overfitting): prediction
## from a rank-deficient fit may be misleading
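# the warning appears because the training set has fewer rows (about 80) than the
# model has coefficients (96), so some coefficients cannot be estimated and are
# returned as NA (the fit is rank-deficient)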
# compare all three error columns on the test rows only
overfitting %>%
  filter(!is_train) %>%
  summarize(mean(abs(res_no_overfit)),
            mean(abs(res_overfit)),
            mean(abs(res_overfit_train)))
## # A tibble: 1 x 3
## `mean(abs(res_no_overfi~ `mean(abs(res_overfi~ `mean(abs(res_overfit_tr~
## <dbl> <dbl> <dbl>
## 1 0.210 0.0806 98.8
# Now the "true face" of the model is discovered. See how high the error rate of the test set is!
# Beware of overfitting models. Always use train/test. Watch out for n and p.
Note that there is no "iron rule" for the ratio between \(n\) and \(p\), but the test-set error rate often has a direct business meaning, and through that meaning you can tell whether the model is helping or not. You can also compare a naive baseline model against your own model, to see how much the more complex model actually contributes.
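As a final illustration, here is a minimal sketch of such a baseline comparison (the column names baseline_mae and overfit_mae are ours, for illustration): a naive model that always predicts the training-set mean of y, measured against the overfit model on the same test rows.
# a naive baseline: always predict the training-set mean of y
y_train_mean <- overfitting %>%
  filter(is_train) %>%
  pull(y) %>%
  mean()
# compare the baseline's test-set error to the overfit model's test-set error
overfitting %>%
  filter(!is_train) %>%
  summarize(baseline_mae = mean(abs(y - y_train_mean)),
            overfit_mae = mean(abs(res_overfit_train)))
# if the complex model cannot beat even this baseline on the test set,
# it is not contributing anything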