Summary

본 내용은 MIT 18.065 강의 14번 내용의 일부 정리입니다. 자세한 내용은 해당 강의를 참고하세요. 그리고 https://blog.naver.com/skkong89/221783800314 내용을 숙지하셔야, 아래 수식을 이해할 수 있습니다.

R markdown 에 수식 표현하기

\(Ax = b\) 에 대한 최소제곱 해를 구한다고 해보자. normal equation 을 사용하면 \(A^TA\hat{x} = A^Tb\) 을 풀어서 가장 비슷한 해인 \(\hat{x}\) 를 구해야 한다. 즉, \(\hat{x} = (A^TA)^{-1}b\) 를 구해야 하고 여기서 \((A^TA)^{-1}\) 가 필요하다. 그런데 \(A^TA\) 의 역행렬을 구하는 비용이 생각보다 크다.

새로운 데이터가 들어올때마다 normal equation 을 사용해서 최소제곱 문제를 풀려면, 매번 새로운 \(A_{new}^TA_{new}\) 의 역행렬을 계산해야 한다.

이 문제를 좀더 쉽게 해결할 수는 없을까? 최대한 기존 계산한 것을 재사용하고, nxn 사이즈의 전체 행렬에 대한 역행렬이 아니라, 그 보다 작은 사이즈인 kxk 의 역행렬을 구해서 문제를 해결할 수는 없을까?

여기서 Sherman-Morrison-Woodbury formula 가 사용된다.

\(M = A - UV^T\) 라면, \[M^{-1} = (A - UV^T)^{-1} = A^{-1} + A^{-1}U(I - V^TA^{-1}U)^{-1}V^TA^{-1}\]

M 은 nxn 사이즈의 행렬이고, U, V 는 nxk 사이즈의 행렬이다.

예제 풀이

1차 normal equation 적용

iris 예제를 통해서 행렬 A, b 를 구한다. 여기서 b 는 iris 타입에 대한 정수형 값이 된다.

library('caret')
options(digits = 2)

A <- iris # data 원본 유지
head(A) # 자료 일부 확인
##   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1          5.1         3.5          1.4         0.2  setosa
## 2          4.9         3.0          1.4         0.2  setosa
## 3          4.7         3.2          1.3         0.2  setosa
## 4          4.6         3.1          1.5         0.2  setosa
## 5          5.0         3.6          1.4         0.2  setosa
## 6          5.4         3.9          1.7         0.4  setosa
dim(A) # 150x5
## [1] 150   5
# Species 타입을 숫자로 변환한다.
items <- unique(A$Species)
items <- as.vector(items)
species.type <- 1:length(items)
names(species.type) <- items
species.type
##     setosa versicolor  virginica 
##          1          2          3
# b 는 우리가 추측하고자 하는 결과 변수가 된다.
b <- species.type[A$Species]
b <- as.matrix(b, byrow=T)

# A 에서 문자열 변수는 제외한다.
A <- subset(A, select = -Species)

# A를 행렬로 변환하고 앞에 몇 개를 살펴본다.
A <- as.matrix(A)
head(A)
##   Sepal.Length Sepal.Width Petal.Length Petal.Width
## 1          5.1         3.5          1.4         0.2
## 2          4.9         3.0          1.4         0.2
## 3          4.7         3.2          1.3         0.2
## 4          4.6         3.1          1.5         0.2
## 5          5.0         3.6          1.4         0.2
## 6          5.4         3.9          1.7         0.4

이제 normal equation 으로 \(\hat{x}\) 를 찾아보자. 먼저 역행렬을 구한다.

AT <- t(A)
ATA <- AT %*% A # A^T A 
ATA_inv <- solve(ATA) # A^T A 의 역행렬
ATA_inv
##              Sepal.Length Sepal.Width Petal.Length Petal.Width
## Sepal.Length        0.050      -0.056       -0.046       0.045
## Sepal.Width        -0.056       0.067        0.045      -0.037
## Petal.Length       -0.046       0.045        0.067      -0.098
## Petal.Width         0.045      -0.037       -0.098       0.184

다음은 최적해이다.

# 우리가 찾는 최적해는 다음과 같이 구한다.
x_hat <- ATA_inv %*% AT %*% b
x_hat
##               [,1]
## Sepal.Length 0.062
## Sepal.Width  0.065
## Petal.Length 0.205
## Petal.Width  0.549
# 최적해를 이용해서 기존 iris의 첫번째 데이터와 곱하면, 0.942 를 얻는다. 이것은 setosa 에 가까운 값이 된다.
sum(iris[1,1:4] * x_hat) # 0.942, setosa
## [1] 0.94
sum(iris[150, 1:4] * x_hat) # 2.59, virginica
## [1] 2.6

2차 normal equation 적용

이제 새로운 데이터 v1, v2 가 들어왔을 때, \(\hat{x}_{new}\) 를 구해보자.

v1 <- c(5, 3.4, 1.3, 0.2) # iris[1,] 과 비슷하게
v2 <- c(5.8, 3.1, 5.0, 1.9) # iris[150,] 과 비슷하게
b1 <- 1
b2 <- 3
b_new <- matrix(c(b1, b2), nrow = 2, byrow = F)
dim(b_new)
## [1] 2 1
V <- matrix(c(v1, v2), nrow=4, byrow=F)
dim(V)
## [1] 4 2

\[\hat{x}_{new} = (A^TA + VV^T)^{-1} (A^Tb + Vb_{new})\]

위 식에서 중요한 것은 \((A^TA + VV^T)^{-1}\) 을 구해야 한다는 건데, 이것은 Sherman-Morrison-Woodbury formula 를 이용해서 빠르게 구할 수 있다는 것이, 이 실험의 핵심 내용이다.

이것을 빠르게 구할 수 있는 이유는?

  • 중간에 역행렬을 구하는 공식이 2x2의 역행렬 구하는 문제로 바뀌었고,
  • 위에서 구한 \(A^TA\) 의 역행렬을 그대로 사용하기 때문이다.

\(M^{-1} = (A^TA + VV^T)^{-1}\) 을 먼저 구해보자.

U <- V # 여기서는 U, V 가 동일한 행렬이 된다.
VT <- t(V)
In <- diag(1, 2, 2) # kxk 단위행렬이 필요한데, 랭크2인 V 니까 2x2 로 생성한다.

# 주의사항: 공식에서는 (A - UVT)의 역행렬을 구하는거고, 새로운 normal equation 은
# ATA + VV^T 에 대한 역행렬을 구하기 때문에 아래 부호가 변경된다.
M_inv <- ATA_inv - ATA_inv %*% U %*% solve(In - VT %*% ATA_inv %*% U) %*% VT %*% ATA_inv
M_inv
##              Sepal.Length Sepal.Width Petal.Length Petal.Width
## Sepal.Length        0.049      -0.055       -0.046       0.045
## Sepal.Width        -0.055       0.066        0.044      -0.038
## Petal.Length       -0.046       0.044        0.066      -0.098
## Petal.Width         0.045      -0.038       -0.098       0.184

최종적으로 \(\hat{x}_{new}\) 를 구해보자.

x_hat_new <- M_inv %*% (AT %*% b + V %*% b_new)
x_hat_new
##               [,1]
## Sepal.Length 0.058
## Sepal.Width  0.071
## Petal.Length 0.207
## Petal.Width  0.550

검증작업

A_new 는 새로운 데이터 포인트 2개가 추가된 행렬이라고 하자. 이 행렬 전체를 대상으로 normal equation 을 적용하고, 위에서 구한 값과 비교를 한다.

A_new <- rbind(A, t(V))
b_new1 <- rbind(b, b_new)
AT_new <- t(A_new)
ATA_new <- AT_new %*% A_new # A^T A 
ATA_inv_new <- solve(ATA_new) # A^T A 의 역행렬
ATA_inv_new
##              Sepal.Length Sepal.Width Petal.Length Petal.Width
## Sepal.Length        0.049      -0.055       -0.046       0.045
## Sepal.Width        -0.055       0.066        0.044      -0.038
## Petal.Length       -0.046       0.044        0.067      -0.098
## Petal.Width         0.045      -0.038       -0.098       0.184
# 우리가 찾는 최적해는 다음과 같이 구한다.
x_hat_new <- ATA_inv_new %*% AT_new %*% b_new1
x_hat_new
##               [,1]
## Sepal.Length 0.053
## Sepal.Width  0.078
## Petal.Length 0.210
## Petal.Width  0.550