# Code
# ================================
# Multinomial Logistic Regression + Full Visualization
# ================================
# Narrow the console print width to 50 characters.
options(width = 50L)
# Load libraries
library(nnet)
library(ggplot2)
library(dplyr)
library(GGally)

# Load data
data(iris)

# -------------------------------
# Cross Plot (Pairwise Scatter Matrix)
# -------------------------------
# Pairwise matrix of the four numeric measurements (columns 1:4), coloured
# by species: scatter points above the diagonal, points with a loess
# smoother below, per-species density curves on the diagonal.
#
# Transparency is passed as a fixed panel parameter through wrap() rather
# than inside aes(): a constant written inside aes() is treated as data and
# pushed through an alpha scale instead of being applied literally.
p_cross <- ggpairs(
  iris,
  columns = 1:4,
  mapping = aes(color = Species),
  upper = list(continuous = wrap("points", size = 1.5, alpha = 0.7)),
  lower = list(
    continuous = wrap("smooth", method = "loess", se = FALSE, alpha = 0.7)
  ),
  diag  = list(continuous = "densityDiag")
) +
  theme_minimal()

print(p_cross)

# Code
# -------------------------------
# Fit multinomial logistic model
# -------------------------------
# Softmax (multinomial logit) regression of Species on the four
# measurements. With multinom()'s default iteration cap (maxit = 100) the
# recorded run ended with "stopped after 100 iterations", i.e. the
# optimiser hit the limit instead of converging. Allow more iterations and
# check the convergence flag nnet stores on the fit (1 = limit reached).
model <- multinom(
  Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
  data  = iris,
  maxit = 1000
)

if (isTRUE(model$convergence == 1)) {
  warning("multinom() reached `maxit` before converging; ",
          "coefficients may be unreliable.", call. = FALSE)
}
#> # weights:  18 (10 variable)
#> initial  value 164.791843
#> iter  10 value 16.177348
#> iter  20 value 7.111438
#> iter  30 value 6.182999
#> iter  40 value 5.984028
#> iter  50 value 5.961278
#> iter  60 value 5.954900
#> iter  70 value 5.951851
#> iter  80 value 5.950343
#> iter  90 value 5.949904
#> iter 100 value 5.949867
#> final  value 5.949867
#> stopped after 100 iterations
# Code
# -------------------------------
# Predictions
# -------------------------------
# Hard class labels (type = "class" is predict.multinom()'s default) and,
# per row, the largest fitted class probability as a confidence score.
iris$Predicted <- predict(model, type = "class")
iris$Prob_Max <- apply(X = predict(model, type = "probs"), MARGIN = 1,
                       FUN = max)

# -------------------------------
# Accuracy
# -------------------------------
# In-sample (training-set) accuracy: fraction of rows whose predicted
# label equals the observed species.
accuracy <- mean(iris$Species == iris$Predicted)
print(paste0("Accuracy: ", round(accuracy, digits = 4)))
#> [1] "Accuracy: 0.9867"
# Code
# Confusion matrix: rows are predicted classes, columns the true species.
print(with(iris, table(Predicted = Predicted, Actual = Species)))
#>             Actual
#> Predicted    setosa versicolor virginica
#>   setosa         50          0         0
#>   versicolor      0         49         1
#>   virginica       0          1        49
# Code
# -------------------------------
# Visualization 1: True Classes
# -------------------------------
# Petal length vs. petal width, coloured by the observed species.
p1 <- ggplot(data = iris,
             mapping = aes(x = Petal.Length, y = Petal.Width,
                           colour = Species)) +
  geom_point(size = 2) +
  theme_minimal() +
  labs(title = "Iris Data (True Classes)")

print(p1)

# Code
# -------------------------------
# Visualization 2: Predicted Classes
# -------------------------------
# Same petal scatter, coloured by the model's predicted labels
# (iris$Predicted); comparing it with the true-class plot shows which
# points the model misclassifies.
p2 <- ggplot(data = iris,
             mapping = aes(x = Petal.Length, y = Petal.Width,
                           colour = Predicted)) +
  geom_point(size = 2) +
  theme_minimal() +
  labs(title = "Predicted Classes (Multinomial Logistic)")

print(p2)

# Code
# -------------------------------
# Visualization 3: Decision Boundary
# -------------------------------
# Evaluate the fitted model on a regular grid spanning the observed petal
# measurements. The sepal measurements are held at their sample means, so
# the shaded regions are a 2-D slice of the 4-D decision surface: observed
# points whose sepal values differ from the means can sit inside a region
# of another class.
grid_resolution <- 200L  # grid points per axis

grid <- expand.grid(
  Petal.Length = seq(min(iris$Petal.Length), max(iris$Petal.Length),
                     length.out = grid_resolution),
  Petal.Width  = seq(min(iris$Petal.Width), max(iris$Petal.Width),
                     length.out = grid_resolution)
)

# Fix other variables at mean
grid$Sepal.Length <- mean(iris$Sepal.Length)
grid$Sepal.Width  <- mean(iris$Sepal.Width)

# Predicted class label for every grid cell
grid$Predicted <- predict(model, newdata = grid)

# Plot decision regions. geom_raster() is ggplot2's fast path for tiles of
# equal size, which is exactly what an expand.grid() lattice produces.
p3 <- ggplot() +
  geom_raster(data = grid,
              aes(Petal.Length, Petal.Width, fill = Predicted),
              alpha = 0.3) +
  geom_point(data = iris,
             aes(Petal.Length, Petal.Width, color = Species),
             size = 1.5) +
  labs(title = "Decision Regions (Softmax Classifier)") +
  theme_minimal()

print(p3)