# Load built-in iris dataset (done once; a second load is redundant)
data("iris")
head(iris)
##   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1          5.1         3.5          1.4         0.2  setosa
## 2          4.9         3.0          1.4         0.2  setosa
## 3          4.7         3.2          1.3         0.2  setosa
## 4          4.6         3.1          1.5         0.2  setosa
## 5          5.0         3.6          1.4         0.2  setosa
## 6          5.4         3.9          1.7         0.4  setosa

Let’s plot the number of flowers in each category of our `Species` categorical variable:

# Bar chart with one bar per Species; bar height is the number of rows
ggplot(iris, aes(x = Species)) +
  geom_bar()

Data Wrangling

# Lowercase copy of the (already factor) `Species` column; the model formula
# and evaluation code refer to `species`.
# NOTE(review): both columns end up in `train`/`test`, so a `species ~ .`
# formula would leak the outcome - keep predictors listed explicitly.
iris$species <- as.factor(iris$Species)

# initial_split() samples rows at random; fixing the seed makes the split,
# and every downstream estimate/accuracy, reproducible across runs.
set.seed(123)

# Breaking out training and test datasets: 80% train / 20% test, stratified
# on species so class proportions are similar in both sets.
split <- initial_split(iris, prop = 0.8, strata = species)
train <- training(split)
test <- testing(split)

Modeling

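Let’s use the `multinom_reg()` model from tidymodels. It fits a multinomial logistic regression, which extends logistic regression to an outcome with more than two classes, such as our three species.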

# Multinomial logistic regression of species on two measurements, fitted
# with the nnet engine (parsnip's default for multinom_reg).
# The tidy() output below has coefficient rows for versicolor and virginica
# only: setosa is the reference level.
# NOTE(review): the very large standard errors (e.g. ~68 for Petal.Width)
# may indicate (quasi-)separation of the classes - confirm before
# interpreting p-values.
model <- multinom_reg() %>%
  set_engine("nnet") %>%
  fit(species ~ Sepal.Length + Petal.Width, data = train)
tidy(model)
## # A tibble: 6 × 6
##   y.level    term         estimate std.error statistic p.value
##   <chr>      <chr>           <dbl>     <dbl>     <dbl>   <dbl>
## 1 versicolor (Intercept)    -0.872     25.3    -0.0344   0.973
## 2 versicolor Sepal.Length   -4.34       5.51   -0.788    0.431
## 3 versicolor Petal.Width    33.6       68.2     0.493    0.622
## 4 virginica  (Intercept)   -23.7       25.6    -0.929    0.353
## 5 virginica  Sepal.Length   -4.01       5.50   -0.730    0.466
## 6 virginica  Petal.Width    46.3       68.3     0.678    0.498

Model Evaluation

# Hard class prediction for every test flower (column `.pred_class`)
pred_class <- predict(model, new_data = test, type = "class")

# Predicted probability of each class for every test flower
pred_prob <- predict(model, new_data = test, type = "prob")

# One row per test flower: true species, predicted class, class probabilities
results <- bind_cols(
  select(test, species),
  pred_class,
  pred_prob
)

# TRUE where the predicted class equals the true species. The comparison
# already returns a logical vector, so no ifelse(cond, TRUE, FALSE) wrapper
# is needed.
results$match <- results$.pred_class == results$species

# Count of correct (TRUE) vs incorrect (FALSE) test predictions
ggplot(results, aes(x = match)) +
  geom_bar() +
  labs(
    x = "Prediction Matches Actual Class",
    y = "Count of Flowers",
    title = "Multinomial Regression Results, Iris Dataset"
  )

# Accuracy = correct predictions / all test rows. A logical vector sums as
# 0/1; na.rm = TRUE counts any NA comparison as a miss while still keeping
# it in the denominator.
sum(results$match, na.rm = TRUE) / nrow(results)
## [1] 0.9666667

We see an accuracy of about 96.7%. We should take this with a grain of salt given the small size of our dataset: the 150-flower iris data leaves only about 30 flowers in the 20% test set.