In the previous units we saw a variety of algorithms. Along the way we mentioned a concept called over-fitting. We were also careful, in every computation, to split the data into a train set and a test set.

In this part we demonstrate the importance of the train/test split, and the care required to avoid overfitting.

Overfitting occurs when the number of explanatory variables \(p\) is very large relative to the sample size \(n\).

In other words, as the number of variables grows, various algorithms manage to find a model that fits the observed data well and seemingly describes a statistical relationship. In practice, however, the algorithm is merely describing a random, coincidental relationship, using the many degrees of freedom its parameters provide.
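Concretely: in linear regression, the residual degrees of freedom are \(n - p - 1\). In the demonstration below, \(n = 100\) and \(p = 95\), leaving only \(100 - 95 - 1 = 4\) degrees of freedom, so the model has almost enough free parameters to reproduce the observed \(y\) exactly, even when \(y\) is pure noise.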

Demonstration:

library(tidyverse)

xvars <- data.frame(matrix(runif(100*95), ncol=95))
overfitting <- tibble(y = runif(100)) %>%
  bind_cols(xvars)
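
# no seed is set here, so re-running this block yields different draws;
# calling set.seed() before generating the data would make the output below reproducible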

glimpse(overfitting)
## Observations: 100
## Variables: 96
## $ y   <dbl> 0.7089986, 0.2903008, 0.2325977, 0.5330833, 0.1241898, 0.7...
## $ X1  <dbl> 0.619575535, 0.899619185, 0.526324802, 0.797330922, 0.8003...
## $ X2  <dbl> 0.3356246, 0.1855314, 0.2111435, 0.4857264, 0.2592875, 0.4...
## $ X3  <dbl> 0.05733814, 0.33154772, 0.91916158, 0.19071112, 0.86667501...
## (output truncated: X4 through X95 are further uniform columns with the same structure)
ggplot(overfitting, aes(y)) + geom_histogram()

# these are just uniformly distributed numbers; there should be no relationship
# between any of the variables. Here is a model with just a few X's and no
# overfitting: the model is insignificant, and the only significant coefficient
# is the intercept (which is roughly the mean of y)
lm_no_overfit <- lm(data = overfitting,
                    formula = y ~ X1 + X2 + X3)
summary(lm_no_overfit)
## 
## Call:
## lm(formula = y ~ X1 + X2 + X3, data = overfitting)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -0.47712 -0.23430  0.00585  0.23836  0.49926 
## 
## Coefficients:
##             Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  0.47886    0.08421   5.687  1.4e-07 ***
## X1           0.07653    0.09755   0.784    0.435    
## X2           0.06358    0.09539   0.667    0.507    
## X3          -0.06808    0.09903  -0.687    0.493    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.2767 on 96 degrees of freedom
## Multiple R-squared:  0.01836,    Adjusted R-squared:  -0.01232 
## F-statistic: 0.5984 on 3 and 96 DF,  p-value: 0.6176
# now see what happens when we add all 95 features.
# In particular, look at the R^2: it is almost 1! (The adjusted R^2,
# which penalizes for the number of parameters, is actually negative.)
lm_overfit <- lm(data = overfitting,
                 formula = y ~ .)
summary(lm_overfit)
## 
## Call:
## lm(formula = y ~ ., data = overfitting)
## 
## Residuals:
##          1          2          3          4          5          6 
## -0.0133397 -0.0046578 -0.0891290 -0.0736334 -0.0515931  0.0116765 
## (residuals 7 through 100 omitted)
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)
## (Intercept) -4.122810   3.444076  -1.197    0.297
## X1           0.351015   0.380474   0.923    0.408
## X2          -0.426239   0.640022  -0.666    0.542
## X3          -0.248992   0.717441  -0.347    0.746
## (coefficients X4 through X95 omitted; none of them is statistically significant)
## 
## Residual standard error: 0.3769 on 4 degrees of freedom
## Multiple R-squared:  0.9241, Adjusted R-squared:  -0.8787 
## F-statistic: 0.5126 on 95 and 4 DF,  p-value: 0.8917
# now compare the in-sample errors of the two models
overfitting <- overfitting %>% 
  mutate(res_no_overfit = y - predict(lm_no_overfit, newdata = overfitting),
         res_overfit = y - predict(lm_overfit, newdata = overfitting))

overfitting %>%
  summarize(mean(abs(res_no_overfit)),
            mean(abs(res_overfit)))
## # A tibble: 1 x 2
##   `mean(abs(res_no_overfit))` `mean(abs(res_overfit))`
##                         <dbl>                    <dbl>
## 1                       0.234                   0.0619
# a ~74% reduction (1 - 0.0619/0.234) in mean absolute residual error, in-sample!

So far we have worked without a train/test split, and at first glance it looks as though the model we fitted with many variables is really good. As you have probably guessed, this is a bluff…

Now we will repeat the exercise, only this time we will measure ourselves on a test set.

overfitting <- overfitting %>%
  mutate(is_train = runif(nrow(overfitting)) < 0.8)

# exclude the helper columns added earlier, so that only X1..X95 serve as predictors
lm_overfit_train <- lm(data = overfitting %>% filter(is_train),
                       formula = y ~ . - res_no_overfit - res_overfit - is_train)

overfitting <- overfitting %>%
  mutate(res_overfit_train = y - predict(lm_overfit_train, newdata = overfitting))
## Warning in predict.lm(lm_overfit_train, newdata = overfitting): prediction
## from a rank-deficient fit may be misleading
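# the warning arises because the training set has only ~80 rows while the model
# has 96 parameters (an intercept plus 95 slopes); with more parameters than rows
# the fit is rank-deficient, which is itself a symptom of p being far too large for n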
overfitting %>%
  filter(!is_train) %>%
  summarize(mean(abs(res_no_overfit)),
            mean(abs(res_overfit)),
            mean(abs(res_overfit_train)))
## # A tibble: 1 x 3
##   `mean(abs(res_no_overfi~ `mean(abs(res_overfi~ `mean(abs(res_overfit_tr~
##                      <dbl>                 <dbl>                     <dbl>
## 1                    0.210                0.0806                      98.8
# Now the "true face" of the model is discovered. See how high the error rate of the test set is!
# Beware of overfitting models. Always use train/test. Watch out for n and p.

To summarize

Note that there is no "iron rule" about the ratio between \(n\) and \(p\), but the test-set error rate often carries business meaning, and through that meaning you can tell whether the model actually helps. You can also compare a basic, nominal baseline model against your own, and see how much the more complex model really contributes.
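
As a minimal sketch of such a comparison (assuming the `overfitting` tibble and the `is_train` split from above; the names `baseline_pred`, `mae_baseline`, and `mae_model` are illustrative), a nominal baseline can simply predict the training-set mean of `y`, and its test-set mean absolute error can be put next to the fitted model's:

# a nominal baseline: always predict the training-set mean of y
baseline_pred <- mean(overfitting$y[overfitting$is_train])

overfitting %>%
  filter(!is_train) %>%
  summarize(mae_baseline = mean(abs(y - baseline_pred)),
            mae_model = mean(abs(res_overfit_train)))

If the complex model cannot beat this baseline on the test set, it is not contributing anything beyond the mean.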