בפרק זה נעסוק באלגוריתם KNN שמשמש לסיווג ולרגרסיה. אלגוריתם זה פשוט יחסית להבנה ומהיר למימוש. הוא לא מתאים לסטטיסטיקה היסקית (אלא רק לחיזוי), והוא יכול להיות מאוד יעיל בבעיות בעלות מאפיינים מסוימים.
הרעיון שבבסיס שיטת KNN הוא שכאשר מגיעה תצפית חדשה, וצריך להעריך משתנה מטרה שלה או לסווג אותה לקטגוריה מסוימת, אז בוחנים לאילו תצפיות קיימות היא “קרובה” ומניחים שהערך שלה יהיה הסיווג הנפוץ בקרב התצפיות הקרובות (או ברגרסיה, יהיה ממוצע של ערך משתנה המטרה של השכנים הקרובים).
נמחיש את האלגוריתם באמצעות קובץ הנתונים של הפרחים iris: נניח שאנחנו רוצים לחזות סיווג של פרק מסוים למין שלו, על פי אורך ורוחב של עלי הכותרת שלו.
כלומר, יש לנו נתונים על אורך ורוחב של עלי כותרת של שלושת המינים. בסה“כ 50 תצפיות מכל מין.
כעת אנחנו מודדים שלושה פרחים חדשים (Length, Width):
וצריכים לסווג תצפיות אלו לכל אחד מהמינים (setosa, versicolor, virginica)
library(tidyverse)
iris %>%
count(Species)
## # A tibble: 3 x 2
## Species n
## <fct> <int>
## 1 setosa 50
## 2 versicolor 50
## 3 virginica 50
new_observations <- tribble(
~Sepal.Length, ~Sepal.Width, ~Species,
5, 4, "פרח א",
7.5, 3.1, "פרח ב",
6, 2.8, "פרח ג"
)
ggplot(iris,
aes(x = Sepal.Length, y = Sepal.Width, color = Species)) + geom_point() +
geom_point(data = new_observations, color = "black", size = 5, alpha = 0.7) +
geom_label(inherit.aes = FALSE,
data = new_observations,
aes(x = Sepal.Length,
y = Sepal.Width,
label = Species),
nudge_x=-0.3, show.legend = FALSE)
אינטואיטיבית אפשר לנחש שפרח א’ שייך לסיווג setosa, פרח ב’ שייך לסיווג virginica, ולגבי פרח ג’ זה לא כל כך ברור (או virginica או versicolor). האינטואיציה עבדה כך שהיא התאימה לכל נקודה חדשה את הנקודות שמסביבתה, וכך בדיוק פועל KNN.
למעשה באמצעות KNN ניתן לסווג כל אזור במרחב לאחד משלושת המינים. כל מה שצריך להחליט הוא על כמה נקודות להתבסס (וזה הפרמטר המרכזי של KNN - על כמה שכנים להתבסס).
הקוד הבא ממחיש כיצד מפעילים את knn, כולל שימוש בtrain/test sets:
# Split iris into train/test randomly
iris_split <- iris %>%
mutate(is_train = runif(nrow(iris)) <= 0.8)
## Warning: package 'bindrcpp' was built under R version 3.4.4
iris_train <- iris_split %>%
filter(is_train)
iris_test <- iris_split %>%
filter(!is_train)
# fit the KNN according to the train set
iris_knn <- class::knn(train = iris_train %>% select(Sepal.Length, Sepal.Width) %>% as.matrix(),
test = iris_test %>% select(Sepal.Length, Sepal.Width) %>% as.matrix(),
cl = iris_train %>% select(Species) %>% as.matrix(),
k = 1)
iris_test %>%
mutate(knn_class = iris_knn) %>%
group_by(knn_class, Species) %>%
tally() %>%
spread(key = Species, value = n, fill = 0)
## # A tibble: 3 x 4
## # Groups: knn_class [3]
## knn_class setosa versicolor virginica
## <fct> <dbl> <dbl> <dbl>
## 1 setosa 11 0 0
## 2 versicolor 0 8 3
## 3 virginica 0 3 9
iris_test %>%
mutate(classification_error = iris_knn != Species) %>%
group_by(Species) %>%
summarize(Species_class_error = mean(classification_error))
## # A tibble: 3 x 2
## Species Species_class_error
## <fct> <dbl>
## 1 setosa 0
## 2 versicolor 0.273
## 3 virginica 0.25
החישוב לעיל מציג את הטעות עבור ה-test set. מה תהיה הטעות עבור ה-train set במקרה של k=1?
כפי שצפוי, היכולת לסווג את המין setosa היא טובה, וב-test set הצלחנו לסווג את כל התצפיות נכונה. לעומת זאת, הסיווג של המינים versicolor ו-virginica הוא פחות מוצלח, קשה להפריד בין המינים הללו באמצעות שני המשתנים שהשתמשנו בהם.
למעשה, בהתבסס על שיטת knn אפשר לצייר “רשת” של סיווגים - כל נקודה במרחב תסווג לקטגוריה מסוימת, לפי התצפיות הסמוכות אליה.
# generate a full grid:
knn_grid <- expand.grid(Sepal.Length = seq(3, 8, 0.05), Sepal.Width = seq(2, 4.5, 0.05))
# show the result
head(knn_grid)
## Sepal.Length Sepal.Width
## 1 3.00 2
## 2 3.05 2
## 3 3.10 2
## 4 3.15 2
## 5 3.20 2
## 6 3.25 2
tail(knn_grid)
## Sepal.Length Sepal.Width
## 5146 7.75 4.5
## 5147 7.80 4.5
## 5148 7.85 4.5
## 5149 7.90 4.5
## 5150 7.95 4.5
## 5151 8.00 4.5
# Run the knn algorithm. Let's use the entire dataset.
# k = 1
iris_knn_1 <- class::knn(train = iris %>% select(Sepal.Length, Sepal.Width) %>% as.matrix(),
test = knn_grid,
cl = iris %>% select(Species) %>% as.matrix(),
k = 1)
# k = 2
iris_knn_2 <- class::knn(train = iris %>% select(Sepal.Length, Sepal.Width) %>% as.matrix(),
test = knn_grid,
cl = iris %>% select(Species) %>% as.matrix(),
k = 2)
# k = 10
iris_knn_10 <- class::knn(train = iris %>% select(Sepal.Length, Sepal.Width) %>% as.matrix(),
test = knn_grid,
cl = iris %>% select(Species) %>% as.matrix(),
k = 10)
# Now, I'm going to plot the resulting grids
knn_for_chart <- knn_grid %>%
bind_cols(tibble(`1nn` = iris_knn_1, `2nn` = iris_knn_2, `10nn` = iris_knn_10)) %>%
gather(key = "k", value = "classification", -Sepal.Length, -Sepal.Width) %>%
mutate(k = as_factor(k, levels = c("1nn", "2nn", "10nn")))
# to show you the data set I'm using to plot this after the gather operation
glimpse(knn_for_chart)
## Observations: 15,453
## Variables: 4
## $ Sepal.Length <dbl> 3.00, 3.05, 3.10, 3.15, 3.20, 3.25, 3.30, 3.35,...
## $ Sepal.Width <dbl> 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,...
## $ k <fct> 1nn, 1nn, 1nn, 1nn, 1nn, 1nn, 1nn, 1nn, 1nn, 1n...
## $ classification <chr> "setosa", "setosa", "setosa", "setosa", "setosa...
# the ggplot command
ggplot(knn_for_chart, aes(x = Sepal.Length, y = Sepal.Width, fill = classification)) +
geom_tile(alpha = 0.5) +
facet_wrap(~ k) +
geom_point(inherit.aes = F, data = iris, aes(x = Sepal.Length, y = Sepal.Width, fill = Species),
color = "black", pch = 21, size = 3)
אחד מהכלים החשובים שיש להפעיל, לפני שמשתמשים באלגוריתם של knn הוא מרכוז ו/או התאמה של קנה המידה של הנתונים.
חשבו על שני משתנים, שאחד נע בין 0-100 והשני נע בין 0-1. ההגדרה של “שכונה” מושפעת מאוד מהמרחקים הללו, כך שהמשתנה של 0-100 משפיע בצורה משמעותית הרבה יותר על מציאת השכנים הקרובים. לעיתים ביצוע מרכוז ו/או התאמת קנה המידה של הנתונים ישפר את החיזוי.
mutate ובפונקציה scale כדי ליצור ארבעה משתנים חדשים (גם ב-train set וגם ב-test set), והתאימו מודל חדש המבוסס על משתנים אלו. חשבו את שגיאת הסיווג בכל אחד מהמינים. האם שינוי קנה המידה סייע בחיזוי?לעיתים כדי להגיע למודל טוב נדרש הרבה ניסוי וטעיה, ולפעמים גם מחשבה מחוץ לקופסה (לגבי איך להמיר משתנים שונים).
למודל ה-knn ישנן מספר בעיות:
בפרק הבא נדון במודלים של רגרסיה - מודלים שונים במהותם ממודל ה-knn, פשוטים יחסית, אך יעילים להפליא.