library(tidyverse)
## -- Attaching packages -------------------------------------- tidyverse 1.2.1 --
## v ggplot2 3.2.1 v purrr 0.3.2
## v tibble 2.1.3 v dplyr 0.8.3
## v tidyr 1.0.0 v stringr 1.4.0
## v readr 1.3.1 v forcats 0.4.0
## -- Conflicts ----------------------------------------- tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
## The following object is masked from 'package:purrr':
##
## lift
# Take an explicit copy of the built-in iris data set from the datasets
# package (served from its lazy-load data) into a global `iris` object.
iris <- getExportedValue("datasets", "iris")
# Split the data into 70% training and 30% test sets
# Partition iris into training (70%) and test (30%) sets. createDataPartition()
# samples within each level of Species, so both sets keep the class balance.
# list = FALSE returns a matrix of row indices rather than a list.
# The seed makes the partition reproducible from run to run.
# Note: `split` and `train` shadow base::split() and caret::train(). R still
# finds the functions when they are called, but the names are easy to confuse.
set.seed(2019)
split <- createDataPartition(iris$Species, p = 0.7, list = FALSE)
train <- iris[split, ]
test <- iris[-split, ]
# Fit a k-nearest-neighbours (kNN) classifier on the training set
# Fit a kNN classifier that predicts Species from all other columns.
# caret::train() tunes k by resampling the training data, and that resampling
# is random, so seed it for reproducible tuning results. The explicit
# namespace makes clear this calls caret's function, not the `train` data
# frame passed as `data`.
set.seed(2019)
fit <- caret::train(Species ~ ., data = train, method = "knn")
# Predict on the training set
# Predicted classes for the same rows the model was trained on. These only
# reflect fit to the training data and are typically optimistic compared with
# predictions on the held-out `test` set.
prediction <- predict(object = fit, newdata = train)
# Nest the full iris data by Species
# One row per Species. The remaining columns for that species are stored as a
# data frame in the `data` list-column. ungroup() removes the grouping left by
# group_by(), so later verbs treat `nested` as a plain table rather than
# operating per group.
nested <- iris %>%
  group_by(Species) %>%
  nest() %>%
  ungroup()
# Write a function to apply to each nested data frame, where `fit` is our trained kNN model
# Predict with a fitted model for every row of a data frame.
#
# @param df A data frame containing the predictor columns the model was
#   trained on.
# @param model A fitted model with a predict() method. Defaults to `fit`
#   (the kNN model trained above). The default is looked up when the function
#   is called, so existing calls applymodel(df) behave as before.
# @return Whatever predict() returns for `model` and `df`. For a caret
#   classifier this is presumably one predicted class per row (confirm for
#   other model types).
applymodel <- function(df, model = fit) {
  predict(model, newdata = df)
}
# Apply the model to each nested data frame and view the predictions
# For each species, predict on its nested data and keep every prediction next
# to the observation it belongs to. `data` and `prediction` hold one entry per
# flower for each species, so unnesting them together gives one row per flower.
# Unnesting only the predictions would instead repeat the whole `data`
# list-column on every row. view() opens the result in the data viewer.
nested %>%
  mutate(prediction = map(data, applymodel)) %>%
  unnest(c(data, prediction)) %>%
  view()