Kompromis pomiędzy obciążeniem a wariancją (bias vs variance)

Wybór modelu w problemach uczenia nadzorowanego wiąże się z realizacją dwóch sprzecznych celów

Modele złożone dobdrze dopasowują się do danych wyjściowych, ale charakteryzują się dużą zmiennością wartości wyjściowych. Ryzykiem jest nadmierne dopasowanie overfitting

Modele prostsze są obciążone dużym błędem systematyczny (bias) i ich zastosowanie niesie ryzyko niewystarczającego dopasowania (underfitting )

Trzecim składnikiem błędów generalizacji jest nieredukowalny błąd związany ze zmiennością danych

Regularyzacja

Lepszym rozwiązaniem jest gorsze dopasowanie do danych uczących przy równoczesnym ograniczeniu parametrów świadczących o potencjalnie dużym błędzie generalizacji.

Regularyzacja L2 - (Regresja grzbietowa, Ridge regression)

\(J(\theta ) = \frac{1}{2} (X \theta - \mathbf {y})^2 + \lambda \lVert{\theta}\rVert^2\) ,

gdzie

\(\lVert{\theta}\rVert^2\) oznacza normę \(L^2\), czyli: \(\theta^T\theta=\sum_{j=1}^{n}\theta_{j}^2\)

\(J(\theta ) = \frac{1}{2} \sum _{i=1}^m \left( h_\theta (x^{(i)} ) - y^{(i)}\right)^2 +\lambda \sum_{j=1}^{n}\theta_{j}^2\)

  • Kara nie dotyczy wyrazu wolnego!
  • W przypadku regresji grzbietowej należy wystandaryzować zmienne objaśniające

Gdybyśmy tego nie zrobili, to prognozy otrzymywane za pomocą dopasowaneh hiperpłaszczyzny zależałyby od jednostki, w której podane został zmienne objaśniające. Problem ten nie pojawia się, gdy \(\lambda=0\), czyli gdy używamy klasycznej metody najmniejszych kwadratów.

Zapis macierzowy (zakłada standaryzację zmiennej \(y\), dzięki czemu nie jest wymagany wyraz wolny):

\(J(\theta ) = \frac{1}{2} (X \theta - \mathbf {y})^2 + \lambda {\theta}^T\theta\)

\(\nabla_\theta J(\theta ) = X^T X \theta - X^T \mathbf {y} + 2\lambda \theta\)

Parametry minimalizujące funkcję kosztu w przypadku minimalizacji błędu kwadratowego (OLS - Ordinary Least Square, RSS Residuals Squared Sum):

\(\theta = (X^T X )^{-1} X^T \mathbf {y}\)

przyjmują dla regularyzacji L2 postać:

\(\theta = (X^T X +\lambda I)^{-1} X^T \mathbf {y}\)

Jeżeli w X jest wiele mocno skorelowanych atrybutów macierz \((X^TX)^-1\) może nie istnieć (liniowa zależność kolumn). Dodanie wystarczająco dużej wartości \(\lambda\) na przekątnej usuwa osobliwość macierzy.

Zadanie optymalizacji z ograniczeniami

Minimalizacja funkcji kosztu określonej następująco: \[\begin{equation*} J(\theta ) = \sum_{i=1}^m (y_i - \sum_{j=1}^n x_{ij}\theta_j)^2 + \lambda \sum_{j=1}^n \theta_j^2 \end{equation*}\] jest równoważna rozwiązaniu zadania: \[\begin{equation*} J(\theta ) = \sum_{i=1}^m (y_i - \sum_{j=1}^n x_{ij}\theta_j)^2 \end{equation*}\]

przy ograniczeniu:

\(\sum_{j=1}^n \theta_j^2 < c\) dla pewnej wartości \(c>0\)

reg<-lm(dist~speed, cars)

x<-as.matrix(cars$speed, length(cars$speed),1)
y<-as.matrix(cars$dist, length(cars$dist),1)

TPredict<-function(theta, x){
  x<-cbind( rep(1, nrow(x)) , x)  
  return (x %*% theta)
}

TCost<-function(theta, x, y){
  y_pred<-TPredict(theta, x)
  err <- (y_pred - y)^2
  return (sum(err) ) 
}

#Regularyzaja L2
TCost_L2<-function(theta, x, y, lambda){
  y_pred<-TPredict(theta, x)
  err <- (y_pred - y)^2
  penalty<-lambda * sum((theta[-1])^2)
  return (sum(err) + penalty) 
}




w1<-seq(-3,10, length.out = 200)

theta = as.matrix(cbind( rep(reg$coefficients[1], length(w1)), w1 ), 2, length(w1))


error_ols = apply(theta, 1, TCost_L2, x,y, 0)
error = apply(theta, 1, TCost_L2, x,y, 100000)
error1 = apply(theta, 1, TCost_L2, x,y, 50000)

plot(w1, error_ols, type='l', col='black')
lines(w1, error1, type='l', col='green')
lines(w1, error, type='l', col='red')

legend("top",  legend=c("OLS", "Lambda=5E04", "Lambda=1E05"),
       col=c("black", "green", "red"), lty=1:2, cex=0.8)

Regularyzacja L1 - (Lasso regression)

W przypadku regularyzacji L1 składnikiem regularyzującym jest norma L1 (czyli suma wartości bezwzględnych wag)

\(J(\theta ) = \frac{1}{2} \sum _{i=1}^m \left( h_\theta (x^{(i)} ) - y^{(i)}\right)^2 + \lambda \sum_{i=1}^{n}\vert\theta_i\vert\)

Funkcja ta nie jest różniczkowalna w zerze, więc nie zostanie podane rozwiązanie analityczne.

Mimo to można zastosować gradientowe metody optymalizacji.

Przykład jak wygląda funkcja kosztu w przypadku dodania regularyzacji L1:

reg<-lm(dist~speed, cars)

x<-as.matrix(cars$speed, length(cars$speed),1)
y<-as.matrix(cars$dist, length(cars$dist),1)

#Regularyzaja L1
TCost_L1<-function(theta, x, y, lambda){
  y_pred<-TPredict(theta, x)
  err <- (y_pred - y)^2
  penalty<-lambda * sum(abs(theta[-1]))
  return (sum(err) + penalty) 
}

w1<-seq(-3,10, length.out = 200)

theta = as.matrix(cbind( rep(reg$coefficients[1], length(w1)), w1 ), 2, length(w1))


error_ols = apply(theta, 1, TCost_L1, x,y, 0)
error = apply(theta, 1, TCost_L1, x,y, 100000)
error1 = apply(theta, 1, TCost_L1, x,y, 50000)

plot(w1, error_ols, type='l', col='black')
lines(w1, error1, type='l', col='green')
lines(w1, error, type='l', col='red')

legend("top",  legend=c("OLS", "Lambda=5E04", "Lambda=1E05"),
       col=c("black", "green", "red"), lty=1:2, cex=0.8)

Dobór parametru \(\lambda\)

Optymalną wartość parametru \(\lambda\) dobieramy metodą walidacji krzyżowej (cross-validation)

Regularyzacja w pakiecie glmnet

Model regresji liniowej dla prognozowania zmiennej Salary w zbiorze Hitters z pakietu ISLR.

Wykorzystana została funkcja model.matrix do budowy macierzy \(X\) oraz \(Y\).

W macierzy \(X\) nie występuje kolumna \(x_0\) zawierająca 1: pakiet glmnet domyślnie centruje zmienną Y oraz standaryzuje pozostałe wektory, w związku z czym nie ma konieczności pozostawienia kolumny.

data(Hitters, package='ISLR' )
head(Hitters)
##                   AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits
## -Andy Allanson      293   66     1   30  29    14     1    293    66
## -Alan Ashby         315   81     7   24  38    39    14   3449   835
## -Alvin Davis        479  130    18   66  72    76     3   1624   457
## -Andre Dawson       496  141    20   65  78    37    11   5628  1575
## -Andres Galarraga   321   87    10   39  42    30     2    396   101
## -Alfredo Griffin    594  169     4   74  51    35    11   4408  1133
##                   CHmRun CRuns CRBI CWalks League Division PutOuts Assists
## -Andy Allanson         1    30   29     14      A        E     446      33
## -Alan Ashby           69   321  414    375      N        W     632      43
## -Alvin Davis          63   224  266    263      A        W     880      82
## -Andre Dawson        225   828  838    354      N        E     200      11
## -Andres Galarraga     12    48   46     33      N        E     805      40
## -Alfredo Griffin      19   501  336    194      A        W     282     421
##                   Errors Salary NewLeague
## -Andy Allanson        20     NA         A
## -Alan Ashby           10  475.0         N
## -Alvin Davis          14  480.0         A
## -Andre Dawson          3  500.0         N
## -Andres Galarraga      4   91.5         N
## -Alfredo Griffin      25  750.0         A
#usuniecie wierszy z wartosciami NA 
Hitters=na.omit(Hitters)

# utworzenie macierzy na podstawie modelu, usunięcie pierwszej kolumny (wyraz  wolny Intercept = 1) 
X=model.matrix(Salary~., Hitters)[,-1]
Y=Hitters$Salary

Model regresji liniowej (OLS - Ordinary Least Square)

ols<-lm(Salary~., Hitters)
summary(ols)
## 
## Call:
## lm(formula = Salary ~ ., data = Hitters)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -907.62 -178.35  -31.11  139.09 1877.04 
## 
## Coefficients:
##               Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  163.10359   90.77854   1.797 0.073622 .  
## AtBat         -1.97987    0.63398  -3.123 0.002008 ** 
## Hits           7.50077    2.37753   3.155 0.001808 ** 
## HmRun          4.33088    6.20145   0.698 0.485616    
## Runs          -2.37621    2.98076  -0.797 0.426122    
## RBI           -1.04496    2.60088  -0.402 0.688204    
## Walks          6.23129    1.82850   3.408 0.000766 ***
## Years         -3.48905   12.41219  -0.281 0.778874    
## CAtBat        -0.17134    0.13524  -1.267 0.206380    
## CHits          0.13399    0.67455   0.199 0.842713    
## CHmRun        -0.17286    1.61724  -0.107 0.914967    
## CRuns          1.45430    0.75046   1.938 0.053795 .  
## CRBI           0.80771    0.69262   1.166 0.244691    
## CWalks        -0.81157    0.32808  -2.474 0.014057 *  
## LeagueN       62.59942   79.26140   0.790 0.430424    
## DivisionW   -116.84925   40.36695  -2.895 0.004141 ** 
## PutOuts        0.28189    0.07744   3.640 0.000333 ***
## Assists        0.37107    0.22120   1.678 0.094723 .  
## Errors        -3.36076    4.39163  -0.765 0.444857    
## NewLeagueN   -24.76233   79.00263  -0.313 0.754218    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 315.6 on 243 degrees of freedom
## Multiple R-squared:  0.5461, Adjusted R-squared:  0.5106 
## F-statistic: 15.39 on 19 and 243 DF,  p-value: < 2.2e-16
coef(ols)
##  (Intercept)        AtBat         Hits        HmRun         Runs 
##  163.1035878   -1.9798729    7.5007675    4.3308829   -2.3762100 
##          RBI        Walks        Years       CAtBat        CHits 
##   -1.0449620    6.2312863   -3.4890543   -0.1713405    0.1339910 
##       CHmRun        CRuns         CRBI       CWalks      LeagueN 
##   -0.1728611    1.4543049    0.8077088   -0.8115709   62.5994230 
##    DivisionW      PutOuts      Assists       Errors   NewLeagueN 
## -116.8492456    0.2818925    0.3710692   -3.3607605  -24.7623251
#suma współczynników (bez wyrazu wolnego) 
theta_sum = sum(abs(coef(ols)[-1]))
#średni błąd kwadratowy
mse = mean((Y-predict(ols, Hitters))^2)

Suma współczynników (bez wyrazu wolnego) : 238.729529

Średni błąd kwadratowy (MSE - Mean Squared Error) : 9.201786910^{4}

L2

#Ridge regression 
library(glmnet)
## Loading required package: Matrix
## Loading required package: foreach
## Loaded glmnet 2.0-16
#alpha=0 oznacza, że L1 jest usuwane, z L2 wprowadzonae 
#alpha = 1 L2 usuwamy , L1 wprowadzone
#Funkcja jest wrażliwa na parametry thresh, zwiększenie tej 

fit_ridge <- glmnet(X,Y, alpha=0, lambda=100, thresh=10^-15)

theta_sum <- sum(abs(coef(fit_ridge)[-1]))
mse <- mean((Y-predict(fit_ridge, X))^2)

coef(fit_ridge)
## 20 x 1 sparse Matrix of class "dgCMatrix"
##                        s0
## (Intercept)  2.095504e+01
## AtBat       -1.151051e-01
## Hits         1.388104e+00
## HmRun       -7.056874e-01
## Runs         1.167516e+00
## RBI          8.491300e-01
## Walks        2.189345e+00
## Years       -2.818897e+00
## CAtBat       9.663782e-03
## CHits        8.240767e-02
## CHmRun       5.387731e-01
## CRuns        1.637887e-01
## CRBI         1.742471e-01
## CWalks      -4.205925e-02
## LeagueN      3.677668e+01
## DivisionW   -1.086633e+02
## PutOuts      2.275099e-01
## Assists      7.684556e-02
## Errors      -2.631419e+00
## NewLeagueN  -5.429521e-03

Suma współczynników (bez wyrazu wolnego) : 158.6258591

Średni błąd kwadratowy (MSE - Mean Squared Error) : 1.033300710^{5}

L1

#Lasso regression 
library(glmnet)

#alpha=0 oznacza, że L1 jest usuwane, z L2 wprowadzonae 
#alpha = 1 L2 usuwamy , L1 wprowadzone
#Funkcja jest wrażliwa na parametry thresh, zwiększenie tej 

fit_lasso <- glmnet(X,Y, alpha=1, lambda=10, thresh=10^-15)

theta_sum <- sum(abs(coef(fit_lasso)[-1]))
mse <- mean((Y-predict(fit_lasso, X))^2)

coef(fit_lasso)
## 20 x 1 sparse Matrix of class "dgCMatrix"
##                        s0
## (Intercept)   -1.32433735
## AtBat          .         
## Hits           2.00924028
## HmRun          .         
## Runs           .         
## RBI            .         
## Walks          2.25894158
## Years          .         
## CAtBat         .         
## CHits          .         
## CHmRun         0.02748497
## CRuns          0.21462840
## CRBI           0.41296526
## CWalks         .         
## LeagueN       18.72897616
## DivisionW   -115.29332251
## PutOuts        0.23574254
## Assists        .         
## Errors        -0.78916974
## NewLeagueN     .

Suma współczynników (bez wyrazu wolnego) : 139.9704714

Średni błąd kwadratowy (MSE - Mean Squared Error) : 1.038401510^{5}

Dobór Lambda metodą sprawdzenia krzyżowego

par(mfrow = c(1, 2))
fit_ridge = glmnet(X, Y, alpha = 0)
plot(fit_ridge)
plot(fit_ridge, xvar = "lambda", label = TRUE)

Wykres błędu w zależności od wartości lambda. Wykorzystana funkcja sprawdzenia krzyżowego.

Na wykresie zaznaczone są dwie wartości \(\lambda\). Pierwsza z nich odpowiada minimalnej wartości błędu MSE.

Druga odpowiada wartości, przy której MSE znajduje się w odległości jednego odchylenia standardowego od błędu najmniejszego.

fit_ridge_cv = cv.glmnet(X, Y, alpha = 0)
plot(fit_ridge_cv)

# fitted coefficients, using minimum lambda
best_coef = coef(fit_ridge_cv, s = "lambda.min")

theta_sum <- sum(abs(best_coef[-1]))
mse <- mean((Y-predict(fit_ridge_cv, X))^2)

Suma współczynników (bez wyrazu wolnego) : 215.9064543

Średni błąd kwadratowy (MSE - Mean Squared Error) : 1.30404910^{5}

Wartość lambda:

library(broom)
glance(fit_ridge_cv) 
## # A tibble: 1 x 2
##   lambda.min lambda.1se
##        <dbl>      <dbl>
## 1       28.0      2437.