본 내용은 MIT 18.065 강의 14번 내용의 일부 정리입니다. 자세한 내용은 해당 강의를 참고하세요. 그리고 https://blog.naver.com/skkong89/221783800314 내용을 숙지하셔야, 아래 수식을 이해할 수 있습니다.
R markdown 에 수식 표현하기
\(Ax = b\) 에 대한 최소제곱 해를 구한다고 해보자. normal equation 을 사용하면 \(A^TA\hat{x} = A^Tb\) 을 풀어서 가장 비슷한 해인 \(\hat{x}\) 를 구해야 한다. 즉, \(\hat{x} = (A^TA)^{-1}b\) 를 구해야 하고 여기서 \((A^TA)^{-1}\) 가 필요하다. 그런데 \(A^TA\) 의 역행렬을 구하는 비용이 생각보다 크다.
새로운 데이터가 들어올때마다 normal equation 을 사용해서 최소제곱 문제를 풀려면, 매번 새로운 \(A_{new}^TA_{new}\) 의 역행렬을 계산해야 한다.
이 문제를 좀더 쉽게 해결할 수는 없을까? 최대한 기존 계산한 것을 재사용하고, nxn 사이즈의 전체 행렬에 대한 역행렬이 아니라, 그 보다 작은 사이즈인 kxk 의 역행렬을 구해서 문제를 해결할 수는 없을까?
여기서 Sherman-Morrison-Woodbury formula 가 사용된다.
\(M = A - UV^T\) 라면, \[M^{-1} = (A - UV^T)^{-1} = A^{-1} + A^{-1}U(I - V^TA^{-1}U)^{-1}V^TA^{-1}\]
M 은 nxn 사이즈의 행렬이고, U, V 는 nxk 사이즈의 행렬이다.
iris 예제를 통해서 행렬 A, b 를 구한다. 여기서 b 는 iris 타입에 대한 정수형 값이 된다.
library('caret')
options(digits = 2)
A <- iris # data 원본 유지
head(A) # 자료 일부 확인
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1 5.1 3.5 1.4 0.2 setosa
## 2 4.9 3.0 1.4 0.2 setosa
## 3 4.7 3.2 1.3 0.2 setosa
## 4 4.6 3.1 1.5 0.2 setosa
## 5 5.0 3.6 1.4 0.2 setosa
## 6 5.4 3.9 1.7 0.4 setosa
dim(A) # 150x5
## [1] 150 5
# Species 타입을 숫자로 변환한다.
items <- unique(A$Species)
items <- as.vector(items)
species.type <- 1:length(items)
names(species.type) <- items
species.type
## setosa versicolor virginica
## 1 2 3
# b 는 우리가 추측하고자 하는 결과 변수가 된다.
b <- species.type[A$Species]
b <- as.matrix(b, byrow=T)
# A 에서 문자열 변수는 제외한다.
A <- subset(A, select = -Species)
# A를 행렬로 변환하고 앞에 몇 개를 살펴본다.
A <- as.matrix(A)
head(A)
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## 1 5.1 3.5 1.4 0.2
## 2 4.9 3.0 1.4 0.2
## 3 4.7 3.2 1.3 0.2
## 4 4.6 3.1 1.5 0.2
## 5 5.0 3.6 1.4 0.2
## 6 5.4 3.9 1.7 0.4
이제 normal equation 으로 \(\hat{x}\) 를 찾아보자. 먼저 역행렬을 구한다.
AT <- t(A)
ATA <- AT %*% A # A^T A
ATA_inv <- solve(ATA) # A^T A 의 역행렬
ATA_inv
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## Sepal.Length 0.050 -0.056 -0.046 0.045
## Sepal.Width -0.056 0.067 0.045 -0.037
## Petal.Length -0.046 0.045 0.067 -0.098
## Petal.Width 0.045 -0.037 -0.098 0.184
다음은 최적해이다.
# 우리가 찾는 최적해는 다음과 같이 구한다.
x_hat <- ATA_inv %*% AT %*% b
x_hat
## [,1]
## Sepal.Length 0.062
## Sepal.Width 0.065
## Petal.Length 0.205
## Petal.Width 0.549
# 최적해를 이용해서 기존 iris의 첫번째 데이터와 곱하면, 0.942 를 얻는다. 이것은 setosa 에 가까운 값이 된다.
sum(iris[1,1:4] * x_hat) # 0.942, setosa
## [1] 0.94
sum(iris[150, 1:4] * x_hat) # 2.59, virginica
## [1] 2.6
이제 새로운 데이터 v1, v2 가 들어왔을 때, \(\hat{x}_{new}\) 를 구해보자.
v1 <- c(5, 3.4, 1.3, 0.2) # iris[1,] 과 비슷하게
v2 <- c(5.8, 3.1, 5.0, 1.9) # iris[150,] 과 비슷하게
b1 <- 1
b2 <- 3
b_new <- matrix(c(b1, b2), nrow = 2, byrow = F)
dim(b_new)
## [1] 2 1
V <- matrix(c(v1, v2), nrow=4, byrow=F)
dim(V)
## [1] 4 2
\[\hat{x}_{new} = (A^TA + VV^T)^{-1} (A^Tb + Vb_{new})\]
위 식에서 중요한 것은 \((A^TA + VV^T)^{-1}\) 을 구해야 한다는 건데, 이것은 Sherman-Morrison-Woodbury formula 를 이용해서 빠르게 구할 수 있다는 것이, 이 실험의 핵심 내용이다.
이것을 빠르게 구할 수 있는 이유는?
\(M^{-1} = (A^TA + VV^T)^{-1}\) 을 먼저 구해보자.
U <- V # 여기서는 U, V 가 동일한 행렬이 된다.
VT <- t(V)
In <- diag(1, 2, 2) # kxk 단위행렬이 필요한데, 랭크2인 V 니까 2x2 로 생성한다.
# 주의사항: 공식에서는 (A - UVT)의 역행렬을 구하는거고, 새로운 normal equation 은
# ATA + VV^T 에 대한 역행렬을 구하기 때문에 아래 부호가 변경된다.
M_inv <- ATA_inv - ATA_inv %*% U %*% solve(In - VT %*% ATA_inv %*% U) %*% VT %*% ATA_inv
M_inv
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## Sepal.Length 0.049 -0.055 -0.046 0.045
## Sepal.Width -0.055 0.066 0.044 -0.038
## Petal.Length -0.046 0.044 0.066 -0.098
## Petal.Width 0.045 -0.038 -0.098 0.184
최종적으로 \(\hat{x}_{new}\) 를 구해보자.
x_hat_new <- M_inv %*% (AT %*% b + V %*% b_new)
x_hat_new
## [,1]
## Sepal.Length 0.058
## Sepal.Width 0.071
## Petal.Length 0.207
## Petal.Width 0.550
A_new 는 새로운 데이터 포인트 2개가 추가된 행렬이라고 하자. 이 행렬 전체를 대상으로 normal equation 을 적용하고, 위에서 구한 값과 비교를 한다.
A_new <- rbind(A, t(V))
b_new1 <- rbind(b, b_new)
AT_new <- t(A_new)
ATA_new <- AT_new %*% A_new # A^T A
ATA_inv_new <- solve(ATA_new) # A^T A 의 역행렬
ATA_inv_new
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## Sepal.Length 0.049 -0.055 -0.046 0.045
## Sepal.Width -0.055 0.066 0.044 -0.038
## Petal.Length -0.046 0.044 0.067 -0.098
## Petal.Width 0.045 -0.038 -0.098 0.184
# 우리가 찾는 최적해는 다음과 같이 구한다.
x_hat_new <- ATA_inv_new %*% AT_new %*% b_new1
x_hat_new
## [,1]
## Sepal.Length 0.053
## Sepal.Width 0.078
## Petal.Length 0.210
## Petal.Width 0.550