Strategic Operational Insights through Airline Customer Segmentation


Aim: To perform customer segmentation based on flight booking data and derive actionable insights into customer behaviors and preferences to support strategic decision-making in the airline and travel industries.

Research Questions (RQ):

  1. What are the key characteristics and behaviours of customers in different flight booking segments?
  2. How can customer segmentation be used to identify distinct preferences and travel patterns?
  3. How can customer segmentation improve operational efficiency in the airline industry?
  4. Which patterns influence the characteristics of airline reservations?

Data

The dataset used for this analysis is available on Kaggle: Flight Price Prediction

Data Explanation

Variable Name Features
Airline The name of the Airline
Flight Information regarding the plane’s flight code
Source City City from which the flight takes off
Departure Time Time of day at which the plane departs
Stops The number of stops between the source and destination
Arrival Time Time of day at which the plane arrives
Destination City City where the flight will land
Class Seat class (Economy/ Business)
Duration Overall amount of time it takes to travel between cities
Days Left A derived feature: the number of days between booking and travel, calculated by subtracting the booking date from the trip date.
Price Ticket price

Importing Libraries

library(readr)
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(tidyr)
library(ggplot2)
library(scales)
## 
## Attaching package: 'scales'
## The following object is masked from 'package:readr':
## 
##     col_factor
library(reshape2)
## 
## Attaching package: 'reshape2'
## The following object is masked from 'package:tidyr':
## 
##     smiths
library(RColorBrewer)
library(forcats)
library(sf)
## Linking to GEOS 3.11.0, GDAL 3.5.3, PROJ 9.1.0; sf_use_s2() is TRUE
library(leaflet)
library(gridExtra)
## 
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
## 
##     combine
library(caret)
## Loading required package: lattice
library(viridis)
## Loading required package: viridisLite
## 
## Attaching package: 'viridis'
## The following object is masked from 'package:scales':
## 
##     viridis_pal
library(cluster)
library(dbscan)
## 
## Attaching package: 'dbscan'
## The following object is masked from 'package:stats':
## 
##     as.dendrogram
library(pheatmap)

Data Preprocessing

# Load the flight-booking data. The first CSV column has no header, so readr
# auto-names it `...1` (see the "New names" message below); the summary further
# down shows it running from 0 to 300152, i.e. it looks like an exported row
# index rather than a feature.
Clean_Dataset <- read_csv("Clean_Dataset.csv")
## New names:
## Rows: 300153 Columns: 12
## ── Column specification
## ──────────────────────────────────────────────────────── Delimiter: "," chr
## (8): airline, flight, source_city, departure_time, stops, arrival_time, ... dbl
## (4): ...1, duration, days_left, price
## ℹ Use `spec()` to retrieve the full column specification for this data. ℹ
## Specify the column types or set `show_col_types = FALSE` to quiet this message.
## • `` -> `...1`
# Report the table's dimensions (rows, columns)
dim(Clean_Dataset)
## [1] 300153     12
# Show each column's storage type together with its first few values
str(Clean_Dataset)
## spc_tbl_ [300,153 × 12] (S3: spec_tbl_df/tbl_df/tbl/data.frame)
##  $ ...1            : num [1:300153] 0 1 2 3 4 5 6 7 8 9 ...
##  $ airline         : chr [1:300153] "SpiceJet" "SpiceJet" "AirAsia" "Vistara" ...
##  $ flight          : chr [1:300153] "SG-8709" "SG-8157" "I5-764" "UK-995" ...
##  $ source_city     : chr [1:300153] "Delhi" "Delhi" "Delhi" "Delhi" ...
##  $ departure_time  : chr [1:300153] "Evening" "Early_Morning" "Early_Morning" "Morning" ...
##  $ stops           : chr [1:300153] "zero" "zero" "zero" "zero" ...
##  $ arrival_time    : chr [1:300153] "Night" "Morning" "Early_Morning" "Afternoon" ...
##  $ destination_city: chr [1:300153] "Mumbai" "Mumbai" "Mumbai" "Mumbai" ...
##  $ class           : chr [1:300153] "Economy" "Economy" "Economy" "Economy" ...
##  $ duration        : num [1:300153] 2.17 2.33 2.17 2.25 2.33 2.33 2.08 2.17 2.17 2.25 ...
##  $ days_left       : num [1:300153] 1 1 1 1 1 1 1 1 1 1 ...
##  $ price           : num [1:300153] 5953 5953 5956 5955 5955 ...
##  - attr(*, "spec")=
##   .. cols(
##   ..   ...1 = col_double(),
##   ..   airline = col_character(),
##   ..   flight = col_character(),
##   ..   source_city = col_character(),
##   ..   departure_time = col_character(),
##   ..   stops = col_character(),
##   ..   arrival_time = col_character(),
##   ..   destination_city = col_character(),
##   ..   class = col_character(),
##   ..   duration = col_double(),
##   ..   days_left = col_double(),
##   ..   price = col_double()
##   .. )
##  - attr(*, "problems")=<externalptr>
# List the R class of every column as a named character vector
sapply(Clean_Dataset, class)
##             ...1          airline           flight      source_city 
##        "numeric"      "character"      "character"      "character" 
##   departure_time            stops     arrival_time destination_city 
##      "character"      "character"      "character"      "character" 
##            class         duration        days_left            price 
##      "character"        "numeric"        "numeric"        "numeric"
# Five-number summary plus mean for numeric columns; length/class/mode for
# character columns
summary(Clean_Dataset)
##       ...1          airline             flight          source_city       
##  Min.   :     0   Length:300153      Length:300153      Length:300153     
##  1st Qu.: 75038   Class :character   Class :character   Class :character  
##  Median :150076   Mode  :character   Mode  :character   Mode  :character  
##  Mean   :150076                                                           
##  3rd Qu.:225114                                                           
##  Max.   :300152                                                           
##  departure_time        stops           arrival_time       destination_city  
##  Length:300153      Length:300153      Length:300153      Length:300153     
##  Class :character   Class :character   Class :character   Class :character  
##  Mode  :character   Mode  :character   Mode  :character   Mode  :character  
##                                                                             
##                                                                             
##                                                                             
##     class              duration       days_left      price       
##  Length:300153      Min.   : 0.83   Min.   : 1   Min.   :  1105  
##  Class :character   1st Qu.: 6.83   1st Qu.:15   1st Qu.:  4783  
##  Mode  :character   Median :11.25   Median :26   Median :  7425  
##                     Mean   :12.22   Mean   :26   Mean   : 20890  
##                     3rd Qu.:16.17   3rd Qu.:38   3rd Qu.: 42521  
##                     Max.   :49.83   Max.   :49   Max.   :123071
# TRUE if any cell in the table is NA
any(is.na(Clean_Dataset))
## [1] FALSE
# TRUE if any row repeats an earlier row across all 12 columns.
# NOTE(review): the comparison includes the `...1` row-index column, which
# looks unique per row (0..300152 over 300153 rows), so this check may be
# unable to flag rows that are duplicates in every other column - confirm.
any(duplicated(Clean_Dataset))
## [1] FALSE
# Print plain numbers (no scientific notation) on every axis drawn below
options(scipen = 999)

# Three panels side by side; the wider left margin leaves room for the
# comma-formatted y-axis labels
par(mfrow = c(1, 3), mar = c(5, 6, 4, 2))

# Pre-compute a small set of "pretty" tick positions for each variable
ticks_price <- pretty(Clean_Dataset$price, n = 5)
ticks_duration <- pretty(Clean_Dataset$duration, n = 5)
ticks_days_left <- pretty(Clean_Dataset$days_left, n = 5)

# Panel 1: ticket price. The default y-axis is suppressed (yaxt = "n") so a
# custom axis with thousands separators can be drawn instead.
boxplot(
  Clean_Dataset[["price"]],
  main = "Price", col = "lightblue", border = "red",
  outline = TRUE, cex.axis = 0.8, yaxt = "n"
)
axis(side = 2, at = ticks_price, labels = comma(ticks_price), las = 1)

# Panel 2: flight duration
boxplot(
  Clean_Dataset[["duration"]],
  main = "Duration", col = "lightblue", border = "red",
  outline = TRUE, cex.axis = 0.8, yaxt = "n"
)
axis(side = 2, at = ticks_duration, labels = comma(ticks_duration), las = 1)

# Panel 3: days between booking and departure
boxplot(
  Clean_Dataset[["days_left"]],
  main = "Days Left", col = "lightblue", border = "red",
  outline = TRUE, cex.axis = 0.8, yaxt = "n"
)
axis(side = 2, at = ticks_days_left, labels = comma(ticks_days_left), las = 1)

# Flag IQR-rule outliers in one numeric column of a data frame.
#
# A value counts as an outlier when it lies below Q1 - 1.5 * IQR or above
# Q3 + 1.5 * IQR. The lower fence is floored at 0, so for non-negative
# measures (price, duration, days left) it never becomes negative.
# Missing values are ignored when computing the quartiles and are never
# reported as outliers.
#
# Args:
#   data:   A data frame or tibble.
#   column: Name (single string) of a numeric column in `data`.
#
# Returns:
#   A list with `lower_bound` and `upper_bound` (unnamed numeric scalars) and
#   `outliers` (the rows of `data` outside the fences, in their original
#   order).
detect_outliers <- function(data, column) {
  stopifnot(
    is.data.frame(data),
    is.character(column),
    length(column) == 1L,
    column %in% names(data)
  )
  values <- data[[column]]
  if (!is.numeric(values)) {
    stop("Column '", column, "' must be numeric.", call. = FALSE)
  }

  # names = FALSE keeps the fences free of the "25%"/"75%" labels
  quartiles <- quantile(values, probs = c(0.25, 0.75),
                        na.rm = TRUE, names = FALSE)
  iqr <- quartiles[2] - quartiles[1]
  lower_bound <- max(0, quartiles[1] - 1.5 * iqr)
  upper_bound <- quartiles[2] + 1.5 * iqr

  # Guard against NA so logical subsetting cannot produce all-NA rows
  is_outlier <- !is.na(values) & (values < lower_bound | values > upper_bound)

  list(
    lower_bound = lower_bound,
    upper_bound = upper_bound,
    outliers = data[is_outlier, , drop = FALSE]
  )
}

# Run the IQR check on each numeric column of interest
price_outliers <- detect_outliers(Clean_Dataset, "price")
duration_outliers <- detect_outliers(Clean_Dataset, "duration")
days_left_outliers <- detect_outliers(Clean_Dataset, "days_left")

# Report the outlier count and both fences for every column, in the order
# price, duration, days_left
outlier_reports <- list(
  price = price_outliers,
  duration = duration_outliers,
  days_left = days_left_outliers
)
for (column_name in names(outlier_reports)) {
  report <- outlier_reports[[column_name]]
  print(paste0("Number of outliers in ", column_name, ": ",
               nrow(report$outliers)))
  print(paste0("Lower bound for ", column_name, ": ", report$lower_bound))
  print(paste0("Upper bound for ", column_name, ": ", report$upper_bound))
}
## [1] "Number of outliers in price: 123"
## [1] "Lower bound for price: 0"
## [1] "Upper bound for price: 99128"
## [1] "Number of outliers in duration: 2110"
## [1] "Lower bound for duration: 0"
## [1] "Upper bound for duration: 30.18"
## [1] "Number of outliers in days_left: 0"
## [1] "Lower bound for days_left: 0"
## [1] "Upper bound for days_left: 72.5"
# Keep only the rows whose values lie inside the given inclusive ranges.
#
# This replaces two consecutive definitions of the same name, where the second
# silently overrode the first and hard-coded its limits in the body. The
# limits are now an argument whose default reproduces those hard-coded values,
# so `remove_outliers(data)` behaves as before.
#
# Args:
#   data:   A data frame or tibble.
#   bounds: Named list; each name is a numeric column of `data` and each value
#           a length-2 numeric vector `c(lower, upper)`. The defaults are the
#           IQR fences reported for `price` and `duration` on this dataset;
#           pass the `lower_bound`/`upper_bound` values from detect_outliers()
#           to avoid relying on these constants.
#
# Returns:
#   `data` restricted to rows that fall within every range (bounds included).
#   Rows with NA in a bounded column are dropped.
remove_outliers <- function(data,
                            bounds = list(price = c(0, 99128),
                                          duration = c(0, 30.18))) {
  stopifnot(
    is.data.frame(data),
    is.list(bounds),
    length(names(bounds)) == length(bounds),
    all(names(bounds) %in% names(data))
  )

  keep <- rep(TRUE, nrow(data))
  for (column in names(bounds)) {
    limits <- bounds[[column]]
    stopifnot(is.numeric(limits), length(limits) == 2L)
    values <- data[[column]]
    keep <- keep & !is.na(values) &
      values >= limits[1] & values <= limits[2]
  }

  data[keep, , drop = FALSE]
}

# Drop the rows flagged as price or duration outliers
Clean_Dataset <- remove_outliers(data = Clean_Dataset)

# Report how many rows survive the filter
print(paste0("Number of rows after outlier removal: ", nrow(Clean_Dataset)))
## [1] "Number of rows after outlier removal: 297920"
# Re-draw the three boxplots on the filtered data, one panel per variable
par(mfrow = c(1, 3))

panel_titles <- c(
  price = "Price",
  duration = "Duration",
  days_left = "Days Left"
)
for (column_name in names(panel_titles)) {
  column_values <- Clean_Dataset[[column_name]]
  tick_positions <- pretty(column_values)

  # Default y-axis is suppressed so the comma-formatted one can replace it
  boxplot(
    column_values,
    main = panel_titles[[column_name]],
    col = "lightblue",
    border = "red",
    outline = TRUE,
    cex.axis = 0.8,
    yaxt = "n"
  )
  axis(2, at = tick_positions, labels = comma(tick_positions))
}

# Preview the first rows of the outlier-filtered data
head(Clean_Dataset)

# Count missing cells across the whole table
total_missing <- sum(is.na(Clean_Dataset))
print(total_missing)
## [1] 0
# Count duplicate rows. The `...1` column is the row index written by the CSV
# export; it differs on every row, so it is left out of the comparison -
# otherwise no two rows could ever be flagged as duplicates.
content_columns <- setdiff(names(Clean_Dataset), "...1")
is_duplicate <- duplicated(Clean_Dataset[content_columns])
total_duplicates <- sum(is_duplicate)
print(total_duplicates)
## [1] 0
# Keep the full rows (index included) of any duplicates for inspection
duplicate_rows <- Clean_Dataset[is_duplicate, ]
print(duplicate_rows)
## # A tibble: 0 × 12
## # ℹ 12 variables: ...1 <dbl>, airline <chr>, flight <chr>, source_city <chr>,
## #   departure_time <chr>, stops <chr>, arrival_time <chr>,
## #   destination_city <chr>, class <chr>, duration <dbl>, days_left <dbl>,
## #   price <dbl>
str(duplicate_rows)
## tibble [0 × 12] (S3: tbl_df/tbl/data.frame)
##  $ ...1            : num(0) 
##  $ airline         : chr(0) 
##  $ flight          : chr(0) 
##  $ source_city     : chr(0) 
##  $ departure_time  : chr(0) 
##  $ stops           : chr(0) 
##  $ arrival_time    : chr(0) 
##  $ destination_city: chr(0) 
##  $ class           : chr(0) 
##  $ duration        : num(0) 
##  $ days_left       : num(0) 
##  $ price           : num(0)
# Freedman-Diaconis bin width: 2 * IQR(x) / n^(1/3)
fd_binwidth <- function(x) {
  2 * IQR(x) / length(x)^(1 / 3)
}

# Density histogram of one numeric column, overlaid with a kernel density
# curve and dashed reference lines at the mean and at mean +/- 1 SD.
#
# Args:
#   data:     Data frame holding the column.
#   column:   Name (string) of the numeric column to plot.
#   title:    Plot title.
#   binwidth: Histogram bin width, in the units of `column`.
#
# Returns:
#   A ggplot object.
plot_distribution <- function(data, column, title, binwidth) {
  center <- mean(data[[column]])
  spread <- sd(data[[column]])

  # One row per reference line. Writing aes(xintercept = mean(x)) against the
  # full data would instead draw one identical line per observation, which is
  # slow on ~300k rows and looks the same.
  mean_line <- data.frame(xintercept = center, label = "Mean")
  sd_lines <- data.frame(
    xintercept = center + c(-1, 1) * spread,
    label = "Mean ± 1 SD"
  )

  ggplot(data, aes(x = .data[[column]])) +
    geom_histogram(aes(y = after_stat(density)),
                   fill = "lightblue", color = "black",
                   binwidth = binwidth) +
    geom_density(color = "red") +
    geom_vline(data = mean_line,
               aes(xintercept = xintercept, linetype = label),
               color = "black", linewidth = 0.5) +
    geom_vline(data = sd_lines,
               aes(xintercept = xintercept, linetype = label),
               color = "yellow", linewidth = 0.5) +
    labs(title = title, x = column) +
    scale_linetype_manual(name = "",
                          values = c("Mean" = "dashed",
                                     "Mean ± 1 SD" = "dashed")) +
    theme_minimal()
}

# Bin widths for the raw variables: Freedman-Diaconis for price, fixed
# unit-wide bins for the two variables plotted on integer breaks
binwidth_price <- fd_binwidth(Clean_Dataset$price)
binwidth_duration <- 1  # Since we want integer breaks for duration
binwidth_days <- 1      # Since we want integer breaks for days_left

# Distributions of the raw variables
p1 <- plot_distribution(Clean_Dataset, "price", "Price", binwidth_price)
p2 <- plot_distribution(Clean_Dataset, "duration", "Duration",
                        binwidth_duration)
p3 <- plot_distribution(Clean_Dataset, "days_left", "Days Left",
                        binwidth_days)

# Combine all plots
grid.arrange(p1, p2, p3, ncol = 1, heights = c(1, 1, 1))

# Data transformation:
#   price_log        - log(1 + price)
#   duration_scaled  - z-score; scale() returns a one-column matrix, so it is
#                      flattened to keep the column a plain numeric vector
#   days_left_scaled - min-max rescaling via scales::rescale()
Clean_Dataset1 <- Clean_Dataset %>%
  mutate(price_log = log1p(price),
         duration_scaled = as.numeric(scale(duration)),
         days_left_scaled = rescale(days_left))
head(Clean_Dataset1)

# Freedman-Diaconis bin widths for the transformed variables
binwidth_price <- fd_binwidth(Clean_Dataset1$price_log)
binwidth_duration <- fd_binwidth(Clean_Dataset1$duration_scaled)
binwidth_days <- fd_binwidth(Clean_Dataset1$days_left_scaled)

# Distributions of the transformed variables
p1 <- plot_distribution(Clean_Dataset1, "price_log",
                        "Price (Log Transformed)", binwidth_price)
p2 <- plot_distribution(Clean_Dataset1, "duration_scaled",
                        "Duration (Standardized)", binwidth_duration)
# days_left is min-max rescaled rather than standardized, so the title says so
p3 <- plot_distribution(Clean_Dataset1, "days_left_scaled",
                        "Days Left (Min-Max Scaled)", binwidth_days)

# Combine all plots
grid.arrange(p1, p2, p3, ncol = 1, heights = c(1, 1, 1))


Data Exploration

# Columns entering the correlation analysis: the three transformed numeric
# measures, followed by every categorical booking attribute
numeric_columns <- c("price_log", "duration_scaled", "days_left_scaled")
categorical_columns <- c(
  "airline", "source_city", "destination_city",
  "departure_time", "arrival_time", "class", "stops"
)

# Encode on a separate object so Clean_Dataset1 keeps its text columns
Clean_Dataset1_encoded <- Clean_Dataset1

# Label-encode a categorical vector as 0-based numeric codes.
#
# Character and factor inputs are converted to a factor and replaced by their
# level index minus one (so the first level maps to 0). Any other input is
# returned untouched.
encode_categorical <- function(x) {
  is_categorical <- is.character(x) || is.factor(x)
  if (!is_categorical) {
    return(x)
  }
  level_index <- as.numeric(as.factor(x))
  level_index - 1
}

# Replace each categorical column by its 0-based label codes
Clean_Dataset1_encoded[categorical_columns] <- lapply(
  Clean_Dataset1_encoded[categorical_columns],
  encode_categorical
)

# Numeric measures first, then the encoded categoricals
all_columns <- c(numeric_columns, categorical_columns)

# Pearson correlation between every pair of columns.
# NOTE(review): the label codes follow the factor level order, which carries
# no ordinal meaning for nominal variables such as airline or city, so
# correlations involving those columns should be read with care.
corr_matrix <- cor(Clean_Dataset1_encoded[all_columns])

# Long format (Var1, Var2, value): one row per matrix cell
melted_corr <- melt(corr_matrix)

# Heatmap with the coefficient printed in each tile. `limits` is spelled out
# in full instead of relying on partial matching of `limit`.
ggplot(melted_corr, aes(x = Var1, y = Var2, fill = value)) +
  geom_tile() +
  scale_fill_gradient2(low = "blue", high = "red",
                       midpoint = 0, limits = c(-1, 1),
                       name = "Correlation") +
  geom_text(aes(label = sprintf("%.2f", value)), size = 3) +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1, size = 10),
        axis.text.y = element_text(size = 10),
        plot.title = element_text(size = 10)) +
  labs(title = "Correlation Heatmap (Numeric + Encoded Categorical Variables)",
       x = "", y = "") +
  coord_fixed()

# Per airline, count how many distinct ticket prices occur; the result is
# stored as Num_Customers and sorted from largest to smallest.
# NOTE(review): distinct prices are used here as the "number of customers";
# other sections count rows with n() - confirm which measure is intended.
price_counts_by_airline <- Clean_Dataset %>%
  group_by(airline) %>%
  summarise(Num_Customers = n_distinct(price))
number_of_customers <- arrange(price_counts_by_airline, desc(Num_Customers))

# Attach the per-airline figure to every booking row
merged_df <- merge(Clean_Dataset, number_of_customers, by = "airline")

# One point shape per airline
n_airlines <- length(unique(merged_df$airline))

# Scatter of ticket price against the per-airline figure, with a dashed red
# marker at the overall mean price
ggplot(
  merged_df,
  aes(x = price, y = Num_Customers, color = airline, shape = airline)
) +
  geom_point(size = 3, alpha = 0.7) +
  geom_vline(
    xintercept = mean(merged_df$price),
    color = "red",
    linetype = "dashed"
  ) +
  # Colour and shape both identify the airline
  scale_color_brewer(palette = "Set2") +
  scale_shape_manual(values = 1:n_airlines) +
  # Thousands separators on the y-axis
  scale_y_continuous(labels = comma) +
  labs(
    title = "Customer Distribution by Airline",
    x = "Price",
    y = "Number of Customers",
    color = "Airline",
    shape = "Airline"
  ) +
  theme_classic() +
  theme(
    plot.title = element_text(hjust = 0.5, face = "bold"),
    axis.title = element_text(face = "bold"),
    legend.position = "right",
    legend.title = element_text(face = "bold"),
    legend.background = element_rect(fill = "white", color = "black"),
    legend.key = element_rect(fill = "white")
  )

# Number of bookings per airline and seat class. `.groups = "drop"` returns an
# ungrouped tibble directly, which also silences summarise()'s regrouping
# message.
customer_data <- Clean_Dataset1 %>%
  group_by(airline, class) %>%
  summarise(customer_count = n(), .groups = "drop")

# Average raw (untransformed) ticket price per airline
price_data <- Clean_Dataset1 %>%
  group_by(airline) %>%
  summarise(avg_price = mean(price)) %>%
  ungroup()

# Constant label so the price line gets its own legend entry
price_data$type <- "Average Price"

# Dodged bars (bookings by class) with the average price drawn as a dashed
# line. The secondary axis uses the identity transform, so bars and line share
# one numeric scale; only the axis titles differ.
ggplot() +
  # Bar plot for customer count by airline and class
  geom_bar(data = customer_data,
           aes(x = airline, y = customer_count, fill = class),
           stat = "identity", position = "dodge", width = 0.7) +
  # Line plot for average price (`linewidth` replaces the deprecated `size`
  # for lines since ggplot2 3.4.0)
  geom_line(data = price_data,
            aes(x = airline, y = avg_price, linetype = type, color = type),
            linewidth = 1, group = 1) +
  # Add points for average price
  geom_point(data = price_data,
             aes(x = airline, y = avg_price, color = type),
             size = 2) +
  # Comma-formatted labels on both y-axes
  scale_y_continuous(
    name = "Customer Count (bars)",
    labels = comma,
    sec.axis = sec_axis(~ ., name = "Average Price (line)", labels = comma)
  ) +
  # Customize colors and line types
  scale_fill_manual(values = c("Economy" = "#4292c6",
                               "Business" = "#2171b5")) +
  scale_color_manual(values = c("Average Price" = "black")) +
  scale_linetype_manual(values = c("Average Price" = "dashed")) +
  # Customize labels, titles, and themes
  labs(title = "Customer Count and Average Price by Airlines",
       x = "Airline",
       fill = "Class",
       color = "Average Price (line)",
       linetype = "Average Price (line)") +
  theme_minimal() +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "top",
    legend.title = element_text(face = "bold"),
    plot.title = element_text(face = "bold", hjust = 0.5)
  ) +
  # Class legend first, price legend second; no separate linetype legend
  guides(
    fill = guide_legend(order = 1, title = "Class"),
    color = guide_legend(order = 2, title = "Average Price (line)"),
    linetype = "none"
  )

# One-hot encode the data with caret: build the dummy-variable specification
# from every column, then apply it. This overwrites the label-encoded
# Clean_Dataset1_encoded created earlier.
# NOTE(review): the formula `~ .` covers all columns, including `flight` and
# the `...1` index - confirm that is intended.
Clean_Dataset1_encoded <- predict(
  dummyVars(" ~ .", data = Clean_Dataset1),
  Clean_Dataset1
)

# Number of bookings for every observed airline / class / stops combination
class_stops_counts <- Clean_Dataset1 %>%
  group_by(airline, class, stops) %>%
  summarise(count = n(), .groups = 'drop')

# Average raw ticket price per airline, labelled for the legend
price_data <- Clean_Dataset1 %>%
  group_by(airline) %>%
  summarise(avg_price = mean(price)) %>%
  mutate(type = "Average Price")

# Every airline x class x stops combination, including ones absent from the
# data. stringsAsFactors = FALSE keeps the keys as character, matching the
# column types they are joined against below.
full_combinations <- expand.grid(
  airline = unique(Clean_Dataset1$airline),
  class = c("Economy", "Business"),
  stops = c("zero", "one", "two_or_more"),
  stringsAsFactors = FALSE
)

# Give missing combinations a count of 0 and build the legend label
class_stops_counts <- full_combinations %>%
  left_join(class_stops_counts, by = c("airline", "class", "stops")) %>%
  replace_na(list(count = 0)) %>%
  mutate(class_stops = paste(class, "Stops:", stops))

# Dodged bars per class/stops combination with the average price as a dashed
# line. The secondary axis is an identity transform, so both series share one
# numeric scale.
ggplot() +
  # Add bars for class and stops combinations
  geom_bar(data = class_stops_counts,
           aes(x = airline,
               y = count,
               fill = class_stops),
           stat = "identity",
           position = position_dodge(width = 0.9),
           alpha = 0.8,
           width = 0.8) +
  # Add line for average price (`linewidth` replaces the deprecated `size`
  # for lines since ggplot2 3.4.0)
  geom_line(data = price_data,
            aes(x = airline,
                y = avg_price,
                color = type,
                linetype = type),
            linewidth = 1,
            group = 1) +
  geom_point(data = price_data,
             aes(x = airline,
                 y = avg_price,
                 color = type),
             size = 3) +
  # Comma-formatted labels on both y-axes
  scale_y_continuous(
    name = "Purchase Count",
    labels = comma,
    sec.axis = sec_axis(~ ., name = "Average Price", labels = comma)
  ) +
  scale_fill_brewer(
    palette = "Set3",
    name = "Class and Stops"
  ) +
  scale_color_manual(
    values = c("Average Price" = "black")
  ) +
  scale_linetype_manual(
    values = c("Average Price" = "dashed")
  ) +
  # Customize theme
  theme_minimal() +
  theme(
    axis.title.y.left = element_text(color = "blue", size = 12),
    axis.text.y.left = element_text(color = "blue"),
    axis.title.y.right = element_text(color = "red", size = 12),
    axis.text.y.right = element_text(color = "red"),
    axis.text.x = element_text(angle = 45, hjust = 1),
    plot.title = element_text(size = 14, hjust = 0.5),
    legend.position = "right",
    legend.text = element_text(size = 5)  # Smaller legend text for better fit
  ) +
  labs(
    title = "Airline Class, Stops and Price Comparison",
    x = "Airline"
  ) +
  # Bar legend first, price legend second; no separate linetype legend
  guides(
    fill = guide_legend(order = 1, title = "Class and Stops"),
    color = guide_legend(order = 2, title = "Average Price (line)"),
    linetype = "none"
  )

# Stacked bar chart of flights per city (one segment per airline) with the
# city totals overlaid as a dashed line with points.
#
# Args:
#   flights_by_airline: Data frame with the city column, `airline`, `flights`.
#   city_totals:        Data frame with the city column and `total`.
#   city_column:        Name (string) of the city column in both data frames.
#   title:              Plot title.
#   x_label:            x-axis label.
#   line_color:         Colour of the totals line and points.
#   fill_colors:        Colours assigned to the airlines.
#   y_limit:            Upper limit of the y-axis; bars rising above it are
#                       dropped by ggplot2, so raise it for larger data.
#
# Returns:
#   A ggplot object.
plot_city_distribution <- function(flights_by_airline, city_totals,
                                   city_column, title, x_label, line_color,
                                   fill_colors, y_limit = 65000) {
  ggplot() +
    geom_bar(data = flights_by_airline,
             aes(x = reorder(.data[[city_column]], -flights),
                 y = flights,
                 fill = airline),
             stat = "identity") +
    # `linewidth` replaces the deprecated `size` for lines (ggplot2 >= 3.4.0)
    geom_line(data = city_totals,
              aes(x = reorder(.data[[city_column]], -total),
                  y = total,
                  color = "Total Flights",
                  linetype = "Total Flights",
                  group = 1),
              linewidth = 1) +
    geom_point(data = city_totals,
               aes(x = reorder(.data[[city_column]], -total),
                   y = total,
                   color = "Total Flights"),
               size = 3) +
    scale_y_continuous(labels = scales::comma,
                       limits = c(0, y_limit),
                       breaks = seq(0, 60000, by = 10000)) +
    scale_fill_manual(values = fill_colors) +
    scale_color_manual(values = c("Total Flights" = line_color)) +
    scale_linetype_manual(values = c("Total Flights" = "dashed")) +
    labs(title = title,
         x = x_label,
         y = "Number of Flights") +
    theme_minimal() +
    theme(
      plot.title = element_text(
        size = 12,
        face = "bold",
        hjust = 0.5,
        margin = margin(t = 10, b = 10),
        color = "#2C3E50"
      ),
      panel.grid.major = element_line(color = "#ECF0F1"),
      panel.grid.minor = element_blank(),
      axis.text.x = element_text(angle = 45, hjust = 1, color = "#34495E"),
      axis.text.y = element_text(color = "#34495E"),
      axis.title = element_text(size = 10, color = "#2C3E50"),
      legend.position = "right",
      legend.title = element_blank(),
      legend.text = element_text(size = 8, color = "#34495E"),
      legend.key.size = unit(0.8, "lines"),
      legend.spacing.y = unit(0.5, "lines"),
      legend.margin = margin(0, 0, 0, 0),
      plot.margin = margin(t = 15, r = 10, b = 10, l = 10),
      plot.background = element_rect(fill = "white", color = NA),
      panel.background = element_rect(fill = "white", color = NA)
    ) +
    guides(
      fill = guide_legend(order = 1, title = "Airlines"),
      color = guide_legend(order = 2, title = "Total Flights (line)"),
      linetype = "none"
    )
}

# Flights per source city and airline, plus the per-city totals (descending)
source_plot_data <- Clean_Dataset1 %>%
  group_by(source_city, airline) %>%
  summarise(flights = n(), .groups = 'drop')

source_airline_distribution <- source_plot_data %>%
  group_by(source_city) %>%
  summarise(total = sum(flights), .groups = 'drop') %>%
  arrange(desc(total))

# Flights per destination city and airline, plus the per-city totals
destination_plot_data <- Clean_Dataset1 %>%
  group_by(destination_city, airline) %>%
  summarise(flights = n(), .groups = 'drop')

destination_airline_distribution <- destination_plot_data %>%
  group_by(destination_city) %>%
  summarise(total = sum(flights), .groups = 'drop') %>%
  arrange(desc(total))

# One fill colour per airline
airline_colors <- c(
  "#2C699A",  # Deep blue
  "#54B4D3",  # Light blue
  "#048BA8",  # Teal
  "#0DB39E",  # Turquoise
  "#16DB93",  # Mint green
  "#83E377"   # Light green
)

# Create source cities plot
source_plot <- plot_city_distribution(
  source_plot_data, source_airline_distribution,
  city_column = "source_city",
  title = "Distribution of Airlines by Source City",
  x_label = "Source City",
  line_color = "#FF6B6B",
  fill_colors = airline_colors
)

# Create destination cities plot
destination_plot <- plot_city_distribution(
  destination_plot_data, destination_airline_distribution,
  city_column = "destination_city",
  title = "Distribution of Airlines by Destination City",
  x_label = "Destination City",
  line_color = "#4834D4",
  fill_colors = airline_colors
)

# Print the plots
print(source_plot)

print(destination_plot)

# Share of bookings per seat class, as a percentage rounded to one decimal.
class_distribution <- Clean_Dataset1 %>%
  count(class) %>%
  mutate(percentage = round(n / sum(n) * 100, 1))

# Create the pie chart: a single stacked bar wrapped into a circle with
# polar coordinates, labelled with the class name and its percentage.
ggplot(class_distribution, aes(x = "", y = percentage, fill = class)) +
  geom_bar(stat = "identity", width = 1) +
  coord_polar(theta = "y") +
  geom_text(aes(label = paste0(class, " (", percentage, "%)")),
            position = position_stack(vjust = 0.5)) +
  labs(title = "Overall Seat Class Distribution") +
  theme_void() +
  # Colours are named so each class keeps the same colour as in the
  # seat-class-by-destination chart, instead of depending on the
  # alphabetical order of the class values.
  scale_fill_manual(values = c("Economy" = "steelblue",
                               "Business" = "salmon")) +
  # Add some styling
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    legend.position = "right"
  )

# See the distribution data:
print(class_distribution)
## # A tibble: 2 × 3
##   class         n percentage
##   <chr>     <int>      <dbl>
## 1 Business  93128       31.3
## 2 Economy  204792       68.7
# Group by stops, count occurrences
stops_counts <- Clean_Dataset1 %>%
  group_by(stops) %>%
  summarise(count = n())

# Define stop order (fewest stops first)
stops_order <- c("zero", "one", "two_or_more")

# Fix the bar order explicitly: most stops first, direct flights last.
# fct_reorder() is not suitable here: it sorts levels by a summary of a
# per-row variable, so passing it the desired order vector produced an
# accidental level order, and relabelling the axis positionally could then
# attach the wrong category name to a bar.
stops_counts$stops <- factor(stops_counts$stops, levels = rev(stops_order))

# Create the plot; axis labels come straight from the factor levels, so each
# bar is always labelled with its own category.
ggplot(stops_counts, aes(x = stops, y = count)) +
  geom_bar(stat = "identity", fill = "steelblue") +
  labs(
    title = "Purchase Count by Number of Stops",
    x = "Number of Stops",
    y = "Purchase Count"
  ) +
  theme_minimal() +
  theme(
    axis.title.x = element_text(size = 12),
    axis.title.y = element_text(size = 12),
    axis.text.x = element_text(
      angle = 45,
      hjust = 1,
      size = 10
    ),
    axis.text.y = element_text(size = 10),
    plot.title = element_text(size = 16, hjust = 0.5)
  ) +
  scale_y_continuous(labels = comma)  # Add commas to y-axis numbers

# Stop categories, from direct flights to the most stops
stops_order <- c("zero", "one", "two_or_more")

# Tickets per airline and stop category; stops becomes a factor so the
# dodged bars and the legend follow stops_order
transit_counts <- count(Clean_Dataset1, airline, stops, name = "count")
transit_counts$stops <- factor(transit_counts$stops, levels = stops_order)

# Dodged bar chart: one bar per stop category within each airline
ggplot(transit_counts, aes(x = airline, y = count, fill = stops)) +
  geom_bar(stat = "identity", position = "dodge", width = 0.7) +
  scale_fill_brewer(palette = "Set2") +
  # Ticket counts on the y-axis, formatted with thousands separators
  scale_y_continuous(name = "Number of Tickets Sold", labels = comma) +
  labs(title = "Stops Preferences by Airline", x = "Airline", fill = "Stops") +
  theme_minimal() +
  theme(
    plot.title = element_text(size = 16, hjust = 0.5),
    axis.title.y = element_text(size = 12, color = "black"),
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.title = element_text(size = 8),
    legend.text = element_text(size = 7)
  )

# Draw a base-graphics pie chart of how often each departure-time slot occurs.
#
# Clean_Dataset1: data frame with a `departure_time` column.
# Called for its plotting side effect on the active graphics device.
departure_plot <- function(Clean_Dataset1) {
  slot_freq <- table(Clean_Dataset1$departure_time)

  # Percentage share of each slot, one decimal place
  slot_share <- round(100 * slot_freq / sum(slot_freq), 1)
  slice_labels <- paste0(names(slot_freq), "\n(", slot_share, "%)")

  # Palette modelled on matplotlib's "Paired" colours
  paired_palette <- c(
    "#A6CEE3", "#1F78B4", "#B2DF8A", "#E31A1C",
    "#FB9A99", "#33A02C", "#FDBF6F", "#FF7F00"
  )

  pie(
    slot_freq,
    labels = slice_labels,
    col = paired_palette,
    radius = 1,
    init.angle = 90,
    main = "Preferred Departure Times",
    cex.main = 1.4
  )
}

# Draw a base-graphics pie chart of how often each arrival-time slot occurs.
#
# Clean_Dataset1: data frame with an `arrival_time` column.
# Called for its plotting side effect on the active graphics device.
arrival_plot <- function(Clean_Dataset1) {
  slot_freq <- table(Clean_Dataset1$arrival_time)

  # Percentage share of each slot, one decimal place
  slot_share <- round(100 * slot_freq / sum(slot_freq), 1)
  slice_labels <- paste0(names(slot_freq), "\n(", slot_share, "%)")

  # Same palette as the departure-time pie chart
  paired_palette <- c(
    "#A6CEE3", "#1F78B4", "#B2DF8A", "#E31A1C",
    "#FB9A99", "#33A02C", "#FDBF6F", "#FF7F00"
  )

  pie(
    slot_freq,
    labels = slice_labels,
    col = paired_palette,
    radius = 1,
    init.angle = 90,
    main = "Preferred Arrival Times",
    cex.main = 1.4
  )
}

# Draw the two pie charts separately, each on a single-panel (1 x 1) layout
par(mfrow = c(1, 1))
departure_plot(Clean_Dataset1)

arrival_plot(Clean_Dataset1)

# Flights per departure-time slot, one column per airline (0 where an
# airline has no flights in a slot)
departure_airline_distribution <- pivot_wider(
  count(Clean_Dataset1, departure_time, airline),
  names_from = airline,
  values_from = n,
  values_fill = 0
)

# Same layout for arrival-time slots
arrival_airline_distribution <- pivot_wider(
  count(Clean_Dataset1, arrival_time, airline),
  names_from = airline,
  values_from = n,
  values_fill = 0
)

# Total flights per slot: stack the airline columns back into rows and sum
# them within each slot
departure_long <- pivot_longer(
  departure_airline_distribution,
  cols = -departure_time,
  values_to = "count"
)
departure_totals <- summarise(
  group_by(departure_long, departure_time),
  total = sum(count)
)

arrival_long <- pivot_longer(
  arrival_airline_distribution,
  cols = -arrival_time,
  values_to = "count"
)
arrival_totals <- summarise(
  group_by(arrival_long, arrival_time),
  total = sum(count)
)

# Fill colours for the airline bars, in the order ggplot assigns them
airline_colors <- c(
  "#2E86AB", "#A23B72", "#F18F01",  # steel blue, raspberry, orange
  "#44CF6C", "#6B4E71", "#C73E1D"   # emerald green, deep purple, coral red
)

# Shared look for the time-distribution charts: minimal theme, dark slate
# text, slanted x-axis labels and only major grid lines
common_theme <- local({
  ink <- "#2C3E50"  # text colour used throughout the theme

  theme_minimal() +
    theme(
      plot.title = element_text(size = 14, hjust = 0.5, face = "bold",
                                color = ink, margin = margin(b = 15)),
      axis.title = element_text(size = 12, color = ink),
      axis.text = element_text(color = ink),
      axis.text.x = element_text(angle = 45, hjust = 1,
                                 margin = margin(t = 5)),
      legend.title = element_text(size = 10, color = ink),
      legend.text = element_text(size = 9, color = ink),
      panel.grid.major = element_line(color = "gray90"),
      panel.grid.minor = element_blank(),
      plot.margin = margin(t = 20, r = 20, b = 20, l = 20)
    )
})

# Create departure plot: stacked bars of flights per airline for each
# departure-time slot, overlaid with a dashed line and points showing the
# slot totals.
# Note: this assignment replaces the pie-chart function of the same name.
departure_plot <- ggplot() +
  geom_bar(data = departure_airline_distribution %>%
             pivot_longer(cols = -departure_time,
                          names_to = "airline",
                          values_to = "count"),
           aes(x = departure_time, y = count, fill = airline),
           stat = "identity",
           position = "stack") +
  # Line thickness is set with `linewidth`; using `size` for lines has been
  # deprecated since ggplot2 3.4.0.
  geom_line(data = departure_totals,
            aes(x = departure_time, y = total, group = 1, linetype = "Total Flights"),
            color = "#008080",
            linewidth = 1.2) +
  geom_point(data = departure_totals,
             aes(x = departure_time, y = total, shape = "Total Flights"),
             color = "#008080",
             size = 3.5) +
  labs(
    title = "Departure Time Distribution by Airline",
    x = "Departure Time",
    y = "Number of Flights",
    fill = "Airline",
    linetype = "Total",
    shape = "Total"
  ) +
  scale_fill_manual(values = airline_colors) +
  scale_linetype_manual(values = c("Total Flights" = "dashed")) +
  scale_shape_manual(values = c("Total Flights" = 19)) +
  scale_y_continuous(labels = scales::comma) +
  common_theme +
  # Linetype and shape guides share one title.
  guides(
    linetype = guide_legend(title = "Total Flights (line)"),
    shape = guide_legend(title = "Total Flights (line)")
  )

# Create arrival plot: stacked bars of flights per airline for each
# arrival-time slot, overlaid with a dashed line and points showing the
# slot totals.
# Note: this assignment replaces the pie-chart function of the same name.
arrival_plot <- ggplot() +
  geom_bar(data = arrival_airline_distribution %>%
             pivot_longer(cols = -arrival_time,
                          names_to = "airline",
                          values_to = "count"),
           aes(x = arrival_time, y = count, fill = airline),
           stat = "identity",
           position = "stack") +
  # Line thickness is set with `linewidth`; using `size` for lines has been
  # deprecated since ggplot2 3.4.0.
  geom_line(data = arrival_totals,
            aes(x = arrival_time, y = total, group = 1, linetype = "Total Flights"),
            color = "#2C3E50",
            linewidth = 1.2) +
  geom_point(data = arrival_totals,
             aes(x = arrival_time, y = total, shape = "Total Flights"),
             color = "#2C3E50",
             size = 3.5) +
  labs(
    title = "Arrival Time Distribution by Airline",
    x = "Arrival Time",
    y = "Number of Flights",
    fill = "Airline",
    linetype = "Total",
    shape = "Total"
  ) +
  scale_fill_manual(values = airline_colors) +
  scale_linetype_manual(values = c("Total Flights" = "dashed")) +
  scale_shape_manual(values = c("Total Flights" = 19)) +
  scale_y_continuous(labels = scales::comma) +
  common_theme +
  # Linetype and shape guides share one title.
  guides(
    linetype = guide_legend(title = "Total Flights (line)"),
    shape = guide_legend(title = "Total Flights (line)")
  )

# Print plots
print(departure_plot)

print(arrival_plot)

# One row per city that appears as a source or a destination, with
# hard-coded coordinates attached. Cities missing from the lookup below get
# NA coordinates, because case_when() has no fallback branch.
city_coords <- Clean_Dataset1 %>%
  select(source_city, destination_city) %>%
  # pivot_longer() replaces the superseded gather(); it stacks both city
  # columns into a single `city` column. Only the set of distinct cities is
  # used afterwards, so the row order of the stacked table does not matter.
  pivot_longer(cols = everything(), names_to = "type", values_to = "city") %>%
  distinct(city) %>%
  mutate(
    latitude = case_when(
      city == "Mumbai" ~ 19.0760,
      city == "Delhi" ~ 28.6139,
      city == "Bangalore" ~ 12.9716,
      city == "Kolkata" ~ 22.5726,
      city == "Hyderabad" ~ 17.3850,
      city == "Chennai" ~ 13.0827
    ),
    longitude = case_when(
      city == "Mumbai" ~ 72.8777,
      city == "Delhi" ~ 77.2090,
      city == "Bangalore" ~ 77.5946,
      city == "Kolkata" ~ 88.3639,
      city == "Hyderabad" ~ 78.4867,
      city == "Chennai" ~ 80.2707
    )
  )

# Base map shared by both views: default tiles, centred on
# lng 78.9629 / lat 20.5937 at zoom level 5
base_map <- setView(addTiles(leaflet()), lng = 78.9629, lat = 20.5937, zoom = 5)

# Source cities: red circle markers, one per row of city_coords
source_map <- addCircleMarkers(
  base_map,
  data = city_coords,
  lng = ~longitude,
  lat = ~latitude,
  radius = 8,
  color = "red",
  fillColor = "red",
  fillOpacity = 0.6,
  label = ~city,
  popup = ~paste("<b>Source:</b>", city)
)

# Destination cities: the same markers drawn in blue
dest_map <- addCircleMarkers(
  base_map,
  data = city_coords,
  lng = ~longitude,
  lat = ~latitude,
  radius = 8,
  color = "blue",
  fillColor = "blue",
  fillOpacity = 0.6,
  label = ~city,
  popup = ~paste("<b>Destination:</b>", city)
)

# Show the two maps, source first
source_map
dest_map
# Create a contingency table: rows are source cities, columns are
# destination cities
source_destination_distribution <- table(Clean_Dataset1$source_city, Clean_Dataset1$destination_city)

# Convert the table to a dataframe for ggplot.
# Var1 holds the source city (table rows), Var2 the destination city
# (table columns), Freq the number of records for that pair.
Clean_Dataset1_heatmap <- as.data.frame(source_destination_distribution)

# Create the heatmap. Destination (Var2) goes on x and source (Var1) on y so
# that the data match the axis titles below; previously the two were swapped
# relative to the labels.
ggplot(Clean_Dataset1_heatmap, aes(x = Var2, y = Var1, fill = Freq)) +
  geom_tile() +
  geom_text(aes(label = Freq), color = "black", size = 3) +
  scale_fill_gradient(low = "grey", high = "red") +
  labs(
    title = "Source to Destination Distribution",
    x = "Destination City",
    y = "Source City"
  ) +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

# Calculate counts by destination and class.
# `.groups = 'drop'` already returns an ungrouped tibble, so no separate
# ungroup() is needed.
# Note: this replaces the overall class_distribution table built earlier.
class_distribution <- Clean_Dataset1 %>%
  group_by(destination_city, class) %>%
  summarise(count = n(), .groups = 'drop')

# Split data into Economy and Business
economy_data <- class_distribution %>%
  filter(class == "Economy")
business_data <- class_distribution %>%
  filter(class == "Business")

# Create the plot: one line plus points per seat class across destinations
ggplot() +
  # Economy connecting line.
  # Line thickness is set with `linewidth`; using `size` for lines has been
  # deprecated since ggplot2 3.4.0.
  geom_line(data = economy_data,
            aes(x = destination_city, y = count, group = 1, color = "Economy"),
            linewidth = 1) +

  # Business connecting line
  geom_line(data = business_data,
            aes(x = destination_city, y = count, group = 1, color = "Business"),
            linewidth = 1, linetype = "dashed") +

  # Economy data points (filled circles)
  geom_point(data = economy_data,
             aes(x = destination_city, y = count, color = "Economy"),
             size = 3, shape = 16) +

  # Business data points (filled squares)
  geom_point(data = business_data,
             aes(x = destination_city, y = count, color = "Business"),
             size = 3, shape = 15) +

  # Customize colors and legend
  scale_color_manual(
    values = c("Economy" = "steelblue",
               "Business" = "salmon"),
    name = "Seat Class"
  ) +

  # Add labels and customize theme
  labs(
    title = "Seat Class Distribution by Destination City",
    x = "Destination City",
    y = "Number of Tickets Sold"
  ) +

  scale_y_continuous(labels = comma) +  # Add commas to y-axis labels

  theme_minimal() +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1),
    panel.grid.minor = element_blank(),
    legend.position = "top",
    plot.title = element_text(hjust = 0.4, face = "bold"),
    panel.grid.major.x = element_blank(),
    plot.margin = margin(t = 20, r = 20, b = 20, l = 20)
  )

# Duration bands: lower bounds 0, 7, 11.17 and 16, with the longest observed
# duration as the upper edge of the last band
breaks <- c(0, 7, 11.17, 16, max(Clean_Dataset1$duration))
labels <- c("Short", "Medium", "Long", "Ultra Long")

# Label every booking with its band. Intervals are closed on the left
# (right = FALSE); include.lowest = TRUE keeps the maximum duration inside
# the last band.
Clean_Dataset1 <- mutate(
  Clean_Dataset1,
  duration_category = cut(
    duration,
    breaks = breaks,
    labels = labels,
    right = FALSE,
    include.lowest = TRUE
  )
)

# Tickets per duration band and seat class
class_distribution_by_duration <- count(
  Clean_Dataset1,
  duration_category,
  class,
  name = "count"
)

# Create separate dataframe for trend lines.
# The rows are derived directly from the summary table so that every count
# stays attached to its own duration category and class. The previous
# construction hard-coded four categories per class and glued the columns
# together by position, which fails or misaligns as soon as one class has no
# bookings in some category.
trend_data <- class_distribution_by_duration %>%
  filter(class %in% c("Economy", "Business")) %>%
  # Economy rows first, then Business, each in duration order
  arrange(desc(class), duration_category) %>%
  mutate(trend = paste(class, "Trend")) %>%
  select(duration_category, trend, count) %>%
  as.data.frame()

# Create the plot: dodged bars per class with a dashed trend line per class
ggplot() +
  # Add bars
  geom_bar(data = class_distribution_by_duration,
           aes(x = duration_category, y = count, fill = class),
           stat = "identity",
           position = "dodge",
           alpha = 0.7) +

  # Add trend lines
  geom_line(data = trend_data,
            aes(x = duration_category, y = count,
                group = trend, color = trend),
            linetype = "dashed") +
  geom_point(data = trend_data,
             aes(x = duration_category, y = count,
                 color = trend),
             size = 3) +

  # Customize colors
  scale_fill_manual("Class",
                    values = c("Economy" = "#4292c6",
                               "Business" = "darkblue")) +
  scale_color_manual("Trend Lines",
                     values = c("Economy Trend" = "blue",
                                "Business Trend" = "orange")) +

  # Add commas to y-axis labels
  scale_y_continuous(labels = comma) +

  # Customize labels and theme
  labs(title = "Class Preference by Flight Duration with Trends",
       x = "Duration Category",
       y = "Number of Tickets Sold") +

  theme_minimal() +
  theme(
    plot.title = element_text(size = 14, face = "bold", hjust = 0.5),
    axis.title = element_text(size = 12),
    axis.text = element_text(size = 10),
    legend.title = element_text(size = 10, face = "bold"),
    legend.position = "right",
    panel.grid.minor = element_blank()
  )

# Tickets per destination city and duration band
duration_distribution <- count(
  Clean_Dataset1,
  destination_city,
  duration_category,
  name = "count"
)

# Dodged bars: one bar per duration band within each destination city
ggplot(
  duration_distribution,
  aes(x = destination_city, y = count, fill = duration_category)
) +
  geom_bar(stat = "identity", position = "dodge", width = 0.8) +

  # Muted palette (in the spirit of seaborn's "muted"), keyed by band name
  scale_fill_manual(values = c(
    "Short" = "#4878CF",
    "Medium" = "#6ACC65",
    "Long" = "#D65F5F",
    "Ultra Long" = "#B47CC7"
  )) +

  # Thousands separators on the y-axis; no gap below the bars and 10%
  # headroom above them
  scale_y_continuous(labels = comma,
                     expand = expansion(mult = c(0, 0.1))) +

  labs(
    title = "Flight Duration Category Distribution by Destination City",
    x = "Destination City",
    y = "Number of Tickets Sold",
    fill = "Duration Category"
  ) +

  theme_minimal() +
  theme(
    # Title
    plot.title = element_text(size = 12, face = "bold", hjust = 0.5,
                              margin = margin(b = 10)),

    # Axes
    axis.title.x = element_text(size = 12, margin = margin(t = 10)),
    axis.title.y = element_text(size = 12, margin = margin(r = 10)),
    axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1),

    # Legend
    legend.title = element_text(size = 8, face = "bold"),
    legend.position = "right",

    # Keep only horizontal major grid lines
    panel.grid.minor = element_blank(),
    panel.grid.major.x = element_blank()
  ) +

  # Do not clip drawing at the panel edge
  coord_cartesian(clip = "off")

# Price bands; the last break is the maximum price rounded up, so the most
# expensive ticket still falls inside the final band
price_bins <- c(0, 5000, 10000, 25000, 50000, ceiling(max(Clean_Dataset1$price)))
price_labels <- c("0-5k", "5k-10k", "10k-25k", "25k-50k", "50k+")

# Tickets per price band and class in wide form (one column per class),
# plus the share of Business tickets within each band
class_price_distribution <- Clean_Dataset1 %>%
  # Left-closed bands; include.lowest keeps the maximum in the last band
  mutate(price_range = cut(price,
                           breaks = price_bins,
                           labels = price_labels,
                           include.lowest = TRUE,
                           right = FALSE)) %>%
  count(price_range, class, name = "count") %>%
  # Add zero-count rows for band/class combinations that never occur
  complete(price_range = factor(price_labels, levels = price_labels),
           class = unique(Clean_Dataset1$class),
           fill = list(count = 0)) %>%
  pivot_wider(names_from = class,
              values_from = count,
              values_fill = 0) %>%
  mutate(Business_Ratio = Business / (Business + Economy))

# Long form of the class counts for the stacked bars
plot_data <- pivot_longer(
  class_price_distribution,
  cols = c(Economy, Business),
  names_to = "Class",
  values_to = "Count"
)

# Stacked bars of ticket counts with the Business ratio drawn on top. The
# ratio is multiplied by the largest single count so it shares the primary
# axis; the secondary axis divides by the same factor to show the ratio.
p <- ggplot() +
  geom_bar(data = plot_data,
           aes(x = price_range, y = Count, fill = Class),
           stat = "identity",
           position = "stack",
           alpha = 0.7) +
  geom_line(data = class_price_distribution,
            aes(x = price_range, y = Business_Ratio * max(plot_data$Count),
                group = 1,
                color = "Business Ratio"),
            linetype = "dashed",
            linewidth = 1) +
  geom_point(data = class_price_distribution,
             aes(x = price_range, y = Business_Ratio * max(plot_data$Count),
                 color = "Business Ratio"),
             size = 3) +
  scale_fill_manual(values = c("Economy" = "skyblue", "Business" = "salmon")) +
  scale_color_manual(values = c("Business Ratio" = "red")) +
  scale_y_continuous(
    name = "Number of Tickets Sold",
    labels = scales::comma,
    sec.axis = sec_axis(~ . / max(plot_data$Count),
                        name = "Business Selection Ratio")
  ) +
  labs(
    title = "Class Selection by Price Range",
    x = "Price Range",
    fill = "Class"
  ) +
  theme_minimal() +
  # Left axis in blue (counts), right axis in red (ratio)
  theme(
    plot.title = element_text(size = 16, hjust = 0.5),
    axis.title.x = element_text(size = 12),
    axis.title.y.left = element_text(color = "blue", size = 12),
    axis.text.y.left = element_text(color = "blue"),
    axis.title.y.right = element_text(color = "red", size = 12),
    axis.text.y.right = element_text(color = "red"),
    legend.position = "top",
    legend.title = element_text(size = 10),
    legend.box = "horizontal"
  )

# Render the chart
print(p)

# Display order of the time-of-day slots
# NOTE(review): "Late_Night" is listed before "Night" -- confirm this is the
# intended order.
time_order <- c("Early_Morning", "Morning", "Afternoon", "Evening", "Late_Night", "Night")

# Store both time columns as factors whose levels follow time_order
Clean_Dataset1 <- mutate(
  Clean_Dataset1,
  departure_time = factor(departure_time, levels = time_order),
  arrival_time = factor(arrival_time, levels = time_order)
)

# Mean ticket price per departure slot and per arrival slot
departure_avg <- summarise(
  group_by(Clean_Dataset1, departure_time),
  avg_price = mean(price)
)

arrival_avg <- summarise(
  group_by(Clean_Dataset1, arrival_time),
  avg_price = mean(price)
)

# Departure time plot: bars of the mean ticket price per departure slot with
# a line and points tracing the same values.
departure_plot <- ggplot(departure_avg, aes(x = departure_time, y = avg_price)) +
  geom_bar(stat = "identity", fill = "skyblue", alpha = 0.7) +
  # Line thickness is set with `linewidth`; using `size` for lines has been
  # deprecated since ggplot2 3.4.0.
  geom_line(aes(group = 1, color = "Trend (Line)"), linewidth = 1) +
  geom_point(aes(color = "Trend (Line)"), size = 3) +
  scale_color_manual(values = c("Trend (Line)" = "blue")) +
  labs(title = "Average Ticket Price by Departure Time",
       x = "Departure Time",
       y = "Average Price",
       color = "") +
  theme_bw() +
  theme(
    plot.title = element_text(size = 14, hjust = 0.5),
    axis.title = element_text(size = 12),
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "top",
    panel.grid.major = element_line(color = "grey90"),
    panel.grid.minor = element_line(color = "grey95")
  ) +
  scale_y_continuous(labels = comma)

# Arrival time plot: bars of the mean ticket price per arrival slot with a
# line and points tracing the same values.
arrival_plot <- ggplot(arrival_avg, aes(x = arrival_time, y = avg_price)) +
  geom_bar(stat = "identity", fill = "lightcoral", alpha = 0.7) +
  # Line thickness is set with `linewidth`; using `size` for lines has been
  # deprecated since ggplot2 3.4.0.
  geom_line(aes(group = 1, color = "Trend (Line)"), linewidth = 1) +
  geom_point(aes(color = "Trend (Line)"), size = 3) +
  scale_color_manual(values = c("Trend (Line)" = "red")) +
  labs(title = "Average Ticket Price by Arrival Time",
       x = "Arrival Time",
       y = "Average Price",
       color = "") +
  theme_bw() +
  theme(
    plot.title = element_text(size = 14, hjust = 0.5),
    axis.title = element_text(size = 12),
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "top",
    panel.grid.major = element_line(color = "grey90"),
    panel.grid.minor = element_line(color = "grey95")
  ) +
  scale_y_continuous(labels = comma)

# Display plots
print(departure_plot)

print(arrival_plot)

# Purchases per days-left value, one column per airline (0 where an airline
# has no purchases for that value), sorted by days left
area_data <- arrange(
  pivot_wider(
    count(Clean_Dataset1, days_left, airline),
    names_from = airline,
    values_from = n,
    values_fill = 0
  ),
  days_left
)

# Back to long form (one row per days-left value and airline) for ggplot
area_data_long <- pivot_longer(
  area_data,
  cols = -days_left,
  names_to = "airline",
  values_to = "count"
)

# Stacked area chart of purchases by days left, split by airline
ggplot(area_data_long, aes(x = days_left, y = count, fill = airline)) +
  geom_area(alpha = 0.8, position = "stack") +
  scale_fill_brewer(palette = "Set2") +
  scale_x_continuous(expand = c(0, 0)) +  # no padding at the x-axis ends
  scale_y_continuous(labels = comma) +
  labs(
    title = "Number of Purchases by Days Left Before Departure",
    x = "Days Left",
    y = "Number of Purchases",
    fill = "Airline"
  ) +
  theme_bw() +
  theme(
    plot.title = element_text(size = 14, hjust = 0.5),
    axis.title = element_text(size = 12),
    axis.text = element_text(size = 10),
    legend.title = element_text(size = 10),
    legend.position = "right",
    panel.grid.major = element_line(color = "grey90"),
    panel.grid.minor = element_line(color = "grey95")
  )

# Purchases per days-left value, one column per seat class (0 where a class
# has no purchases for that value), sorted by days left
class_days_data <- arrange(
  pivot_wider(
    count(Clean_Dataset1, days_left, class),
    names_from = class,
    values_from = n,
    values_fill = 0
  ),
  days_left
)

# Back to long form (one row per days-left value and class) for ggplot
class_days_long <- pivot_longer(
  class_days_data,
  cols = -days_left,
  names_to = "class",
  values_to = "count"
)

# One line with point markers per seat class
ggplot(class_days_long, aes(x = days_left, y = count, color = class)) +
  geom_line(linewidth = 1) +
  geom_point(size = 2) +
  scale_color_brewer(palette = "Set2") +
  # 2% padding at both ends of the x-axis
  scale_x_continuous(expand = expansion(mult = 0.02)) +
  scale_y_continuous(labels = comma) +
  labs(
    title = "Purchases by Days Left Before Departure (Class-wise)",
    x = "Days Left Before Departure",
    y = "Number of Purchases",
    color = "Class"
  ) +
  theme_bw() +
  theme(
    plot.title = element_text(size = 14, hjust = 0.5),
    axis.title = element_text(size = 12),
    axis.text = element_text(size = 10),
    legend.title = element_text(size = 10),
    legend.position = "right",
    panel.grid.major.y = element_line(linetype = "dashed", color = "grey70"),
    panel.grid.major.x = element_line(color = "grey90"),
    panel.grid.minor = element_blank()
  )

# Step 1: per days-left value and seat class, the number of purchases and
# the mean ticket price (missing prices ignored)
summary_by_class <- summarize(
  group_by(Clean_Dataset1, days_left, class),
  purchase_count = n(),
  avg_price = mean(price, na.rm = TRUE),
  .groups = "drop"
)

# Step 2: Define a plotting function
#
# Build a dual-axis chart for one seat class: bars show the number of
# purchases per days-left value (left axis), and a dashed red line with
# points shows the average price (right axis).
#
# class_data: data frame with columns days_left, purchase_count, avg_price.
# class_name: label used in the plot title ("<class_name> Class").
# Returns the ggplot object; the caller prints it.
plot_by_class <- function(class_data, class_name) {
  # Max values for scaling: prices are rescaled so that the highest average
  # price reaches the height of the tallest bar
  max_purchase <- max(class_data$purchase_count, na.rm = TRUE)
  max_price <- max(class_data$avg_price, na.rm = TRUE)

  ggplot(class_data, aes(x = days_left)) +
    # Bars for purchase count
    geom_bar(aes(y = purchase_count), stat = "identity", fill = "skyblue", alpha = 0.7) +
    # Red dashed line for average price. Line thickness is set with
    # `linewidth`; using `size` for lines is deprecated since ggplot2 3.4.0.
    geom_line(aes(y = avg_price * max_purchase / max_price), color = "red", linetype = "dashed", linewidth = 1) +
    # Red points for average price
    geom_point(aes(y = avg_price * max_purchase / max_price), color = "red", size = 2) +
    # Secondary axis for average price: the inverse of the rescaling above.
    # `transform` replaces the `trans` argument, deprecated in ggplot2 3.5.0.
    scale_y_continuous(
      name = "Number of Purchases",
      labels = comma,
      sec.axis = sec_axis(
        transform = ~ . * max_price / max_purchase,
        name = "Average Price",
        labels = comma
      )
    ) +
    labs(
      title = paste(class_name, "Class"),
      x = "Days Left",
      y = "Count / Price (Scaled)"
    ) +
    theme_minimal() +
    theme(
      plot.title = element_text(size = 16, face = "bold", hjust = 0.5),
      axis.title.y.left = element_text(color = "blue", size = 12),
      axis.text.y.left = element_text(color = "blue"),  # First y-axis numbers in blue
      axis.title.y.right = element_text(color = "red", size = 12),
      axis.text.y.right = element_text(color = "red"),  # Secondary y-axis numbers in red
      legend.position = "none",
      panel.grid.major = element_line(color = "gray90")
    )
}

# Step 3: Generate plots for each class
# Each class's rows of summary_by_class are filtered out and passed to
# plot_by_class(). economy_data and business_data replace the per-destination
# tables of the same names created earlier in the file.
# Economy Class
economy_data <- summary_by_class %>% filter(class == "Economy")
economy_plot <- plot_by_class(economy_data, "Economy")
## Warning: The `trans` argument of `sec_axis()` is deprecated as of ggplot2 3.5.0.
## ℹ Please use the `transform` argument instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
print(economy_plot)

# Business Class
business_data <- summary_by_class %>% filter(class == "Business")
business_plot <- plot_by_class(business_data, "Business")
print(business_plot)

# Tickets per departure-time slot and seat class, sorted by slot
departure_distribution <- arrange(
  count(Clean_Dataset1, departure_time, class, name = "count"),
  departure_time
)

# Tickets per arrival-time slot and seat class, sorted by slot
arrival_distribution <- arrange(
  count(Clean_Dataset1, arrival_time, class, name = "count"),
  arrival_time
)

# Departure time plot: overlapping semi-transparent areas plus lines, one
# per seat class, showing tickets sold in each departure-time slot.
departure_plot <- ggplot(departure_distribution,
                         aes(x = departure_time,
                             y = count,
                             fill = class)) +
  # The x variable is discrete, so ggplot2 would otherwise treat every
  # slot/class combination as its own one-point group and the areas would
  # not be drawn as continuous shapes; grouping by class joins the slots.
  geom_area(aes(group = class, color = class), alpha = 0.2, position = 'identity') +
  # Line thickness is set with `linewidth`; using `size` for lines has been
  # deprecated since ggplot2 3.4.0.
  geom_line(aes(group = class, color = class), linewidth = 0.8) +
  scale_y_continuous(labels = comma, expand = expansion(mult = c(0, 0.1))) +
  scale_color_manual(values = c("Business" = "#1f77b4", "Economy" = "#ff7f0e")) +
  scale_fill_manual(values = c("Business" = "#1f77b4", "Economy" = "#ff7f0e")) +
  labs(title = "Seat Class Distribution by Departure Time",
       x = "Departure Time",
       y = "Number of Tickets Sold",
       color = "Class",
       fill = "Class") +
  theme_minimal() +
  theme(
    plot.title = element_text(size = 16, hjust = 0.5),
    axis.title = element_text(size = 12),
    axis.text.x = element_text(angle = 45, hjust = 1),
    panel.grid.major = element_line(color = "gray90"),
    panel.grid.minor = element_blank(),
    legend.position = "top",
    legend.background = element_rect(fill = "white", color = NA),
    plot.margin = margin(t = 20, r = 20, b = 20, l = 20)
  )

# Arrival time plot: overlapping semi-transparent areas plus lines, one per
# seat class, showing tickets sold in each arrival-time slot.
arrival_plot <- ggplot(arrival_distribution,
                       aes(x = arrival_time,
                           y = count,
                           fill = class)) +
  # The x variable is discrete, so ggplot2 would otherwise treat every
  # slot/class combination as its own one-point group and the areas would
  # not be drawn as continuous shapes; grouping by class joins the slots.
  geom_area(aes(group = class, color = class), alpha = 0.2, position = 'identity') +
  # Line thickness is set with `linewidth`; using `size` for lines has been
  # deprecated since ggplot2 3.4.0.
  geom_line(aes(group = class, color = class), linewidth = 0.8) +
  scale_y_continuous(labels = comma, expand = expansion(mult = c(0, 0.1))) +
  scale_color_manual(values = c("Business" = "#1f77b4", "Economy" = "#ff7f0e")) +
  scale_fill_manual(values = c("Business" = "#1f77b4", "Economy" = "#ff7f0e")) +
  labs(title = "Seat Class Distribution by Arrival Time",
       x = "Arrival Time",
       y = "Number of Tickets Sold",
       color = "Class",
       fill = "Class") +
  theme_minimal() +
  theme(
    plot.title = element_text(size = 16, hjust = 0.5),
    axis.title = element_text(size = 12),
    axis.text.x = element_text(angle = 45, hjust = 1),
    panel.grid.major = element_line(color = "gray90"),
    panel.grid.minor = element_blank(),
    legend.position = "top",
    legend.background = element_rect(fill = "white", color = NA),
    plot.margin = margin(t = 20, r = 20, b = 20, l = 20)
  )


# Display both plots
print(departure_plot)

print(arrival_plot)


Clusters

# K-means 

# Sample and prepare features
# The seed is fixed so the 30,000-row random sample is reproducible.
set.seed(42)
df_sampled <- Clean_Dataset1 %>% sample_n(30000)
# Names of the numeric and categorical columns used as clustering inputs.
# NOTE(review): duration_scaled, days_left_scaled and price_log are not
# created in this part of the file -- presumably derived earlier; confirm.
continuous_features <- c('duration_scaled', 'days_left_scaled', 'price_log')
categorical_features <- c('airline', 'source_city', 'departure_time', 'stops',
                          'arrival_time', 'destination_city', 'class')

# One-hot encoding with explicit class encoding
# `~ . - 1` builds indicator columns for every categorical feature without
# an intercept column.
encoded_categorical <- model.matrix(~ . - 1, data = df_sampled[categorical_features])
encoded_categorical_df <- as.data.frame(encoded_categorical)
# Explicit 0/1 indicators for both seat classes (written over any existing
# columns with these names).
encoded_categorical_df$classBusiness <- as.numeric(df_sampled$class == "Business")
encoded_categorical_df$classEconomy <- as.numeric(df_sampled$class == "Economy")
# Drop columns whose names start with "classclass".
# NOTE(review): model.matrix normally names columns <variable><level>
# (e.g. "classEconomy"), so this pattern may match nothing -- confirm.
encoded_categorical_df <- encoded_categorical_df[!grepl("^classclass", names(encoded_categorical_df))]
# Final feature table: continuous columns followed by the indicator columns.
data_encoded <- cbind(df_sampled[continuous_features], encoded_categorical_df)

# Impute and scale
preProcess_missingdata <- preProcess(data_encoded, method = "medianImpute")
data_encoded_imputed <- predict(preProcess_missingdata, data_encoded)
preProcess_scale <- preProcess(data_encoded_imputed, method = c("center", "scale"))
X_scaled <- predict(preProcess_scale, data_encoded_imputed)

# PCA
pca_result <- prcomp(X_scaled, center = TRUE, scale. = TRUE)
X_pca <- pca_result$x[, 1:2]
X_pca_df <- data.frame(X_pca)

# K-means Clustering
n_clusters <- 6
set.seed(42)
kmeans_result <- kmeans(X_pca, centers = n_clusters, nstart = 20)
kmeans_clusters <- kmeans_result$cluster

# Silhouette score
dist_matrix <- dist(X_pca)
silhouette_kmeans <- mean(silhouette(kmeans_clusters, dist_matrix)[, "sil_width"])
cat(sprintf("\nK-means Silhouette Score: %.4f\n", silhouette_kmeans))
## 
## K-means Silhouette Score: 0.4576
# Plot
# Scatter of the two retained principal components, coloured by the K-means
# label of each point.
plot_kmeans <- ggplot(data = X_pca_df, mapping = aes(x = PC1, y = PC2)) +
  geom_point(
    mapping = aes(color = factor(kmeans_clusters)),
    size = 2,
    alpha = 0.6
  ) +
  scale_color_viridis_d() +
  labs(
    x = "PCA Component 1",
    y = "PCA Component 2",
    color = "Cluster",
    title = "K-means Clustering (6 clusters)"
  ) +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 14))
print(plot_kmeans)

# Cluster sizes
# Tabulate how many sampled rows carry each K-means label.
cat("\nK-means cluster sizes:\n")
## 
## K-means cluster sizes:
print(table(kmeans_clusters))
## kmeans_clusters
##    1    2    3    4    5    6 
## 5131 6112 5037 4442 4561 4717
# Cluster characteristics
# Per-cluster mean of each continuous feature, rounded to 3 decimals. The
# labels are attached positionally, so kmeans_clusters must be in the same row
# order as df_sampled.
cluster_stats <- df_sampled[continuous_features] %>%
 mutate(cluster = kmeans_clusters) %>%
 group_by(cluster) %>%
 summarise(across(everything(), mean)) %>%
 round(3)
cat("\nK-means Cluster Characteristics:\n")
## 
## K-means Cluster Characteristics:
print(cluster_stats)
## # A tibble: 6 × 4
##   cluster duration_scaled days_left_scaled price_log
##     <dbl>           <dbl>            <dbl>     <dbl>
## 1       1          -0.979            0.535      8.37
## 2       2           0.962            0.502      8.83
## 3       3          -0.155            0.504      8.93
## 4       4          -0.52             0.558      8.41
## 5       5          -0.396            0.527     10.8 
## 6       6           0.822            0.512     10.9
# Calculate centroids
# K-means was fitted on the first two principal components, so its centers are
# already expressed in PC1/PC2 coordinates.
centroids_pca <- data.frame(
  Cluster = seq_len(n_clusters),
  PC1 = kmeans_result$centers[, 1],
  PC2 = kmeans_result$centers[, 2]
)

# Transform centroids
# Project the 2-D centers back through the two retained loadings, then undo
# prcomp()'s own standardisation: x = z * scale + center.
# (The previous scale(x, center = -center, scale = 1 / scale) call evaluated
# (z + center) * scale, which only matches when center is 0 or scale is 1.)
# NOTE(review): this recovers the coordinates of the matrix passed to prcomp()
# (X_scaled, i.e. already standardised features), not raw feature units.
rotation <- pca_result$rotation[, 1:2]
center <- pca_result$center
scale <- pca_result$scale
centroids_original <- kmeans_result$centers %*% t(rotation)
# prcomp() stores FALSE instead of a vector when a step was skipped
if (is.numeric(scale)) {
  centroids_original <- sweep(centroids_original, 2, scale, "*")
}
if (is.numeric(center)) {
  centroids_original <- sweep(centroids_original, 2, center, "+")
}
centroids_original_df <- as.data.frame(centroids_original)
colnames(centroids_original_df) <- rownames(rotation)
centroids_original_df$Cluster <- seq_len(n_clusters)

# Print results
# First the K-means centers in PC1/PC2 coordinates, rounded to 3 decimals.
cat("\nCluster Centroids (PCA Space):\n")
## 
## Cluster Centroids (PCA Space):
print(round(centroids_pca, 3))
##   Cluster    PC1    PC2
## 1       1 -2.163  1.217
## 2       2 -0.295 -1.788
## 3       3 -0.463  0.472
## 4       4 -1.849 -0.414
## 5       5  2.293  1.556
## 6       6  2.754 -0.625
# Then the same centers mapped back onto the encoded feature columns.
# NOTE(review): the PCA input was the centred/scaled matrix X_scaled, so these
# values look like standardised scores rather than raw units - confirm before
# interpreting magnitudes.
cat("\nCluster Centroids (Original Feature Space):\n")
## 
## Cluster Centroids (Original Feature Space):
print(round(centroids_original_df, 3))
##   duration_scaled days_left_scaled price_log airlineAir_India airlineAirAsia
## 1          -1.060            0.112    -0.941           -0.620          0.197
## 2           0.736           -0.002    -0.337            0.554          0.215
## 3          -0.327            0.026    -0.179           -0.202          0.022
## 4          -0.248            0.082    -0.960           -0.055          0.309
## 5          -0.176           -0.093     1.302           -0.272         -0.484
## 6           0.919           -0.134     1.297            0.488         -0.340
##   airlineGO_FIRST airlineIndigo airlineSpiceJet airlineVistara
## 1           0.331         0.820           0.065         -0.319
## 2           0.030        -0.329           0.160         -0.430
## 3           0.072         0.223          -0.002         -0.026
## 4           0.271         0.373           0.168         -0.560
## 5          -0.328        -0.227          -0.289          0.900
## 6          -0.414        -0.836          -0.154          0.588
##   source_cityChennai source_cityDelhi source_cityHyderabad source_cityKolkata
## 1              0.129           -0.076                0.022              0.011
## 2             -0.232            0.205               -0.046             -0.019
## 3              0.054           -0.039                0.010              0.005
## 4             -0.076            0.096               -0.018             -0.007
## 5              0.227           -0.233                0.048              0.019
## 6             -0.046           -0.006               -0.005             -0.003
##   source_cityMumbai departure_timeMorning departure_timeAfternoon
## 1            -0.068                 0.304                   0.274
## 2             0.068                -0.476                  -0.188
## 3            -0.023                 0.121                   0.083
## 4            -0.001                -0.125                   0.067
## 5            -0.039                 0.431                   0.037
## 6             0.051                -0.143                  -0.243
##   departure_timeEvening departure_timeLate_Night departure_timeNight
## 1                -0.317                    0.099              -0.319
## 2                 0.409                    0.001               0.267
## 3                -0.117                    0.023              -0.102
## 4                 0.066                    0.075              -0.041
## 5                -0.324                   -0.087              -0.115
## 6                 0.189                   -0.120               0.259
##   stopstwo_or_more stopszero arrival_timeMorning arrival_timeAfternoon
## 1           -0.085     0.636              -0.546                 0.151
## 2            0.246    -0.406               0.572                -0.099
## 3           -0.046     0.189              -0.187                 0.045
## 4            0.119     0.177               0.014                 0.040
## 5           -0.285     0.043              -0.363                 0.014
## 6           -0.014    -0.577               0.390                -0.136
##   arrival_timeEvening arrival_timeLate_Night arrival_timeNight
## 1               0.027                  0.228             0.256
## 2              -0.185                  0.049            -0.338
## 3               0.026                  0.047             0.095
## 4              -0.117                  0.208            -0.059
## 5               0.246                 -0.267             0.272
## 6               0.055                 -0.299            -0.149
##   destination_cityChennai destination_cityDelhi destination_cityHyderabad
## 1                  -0.211                 0.331                    -0.083
## 2                   0.221                -0.332                     0.063
## 3                  -0.072                 0.112                    -0.026
## 4                   0.005                 0.002                    -0.016
## 5                  -0.140                 0.198                    -0.020
## 6                   0.151                -0.243                     0.071
##   destination_cityKolkata destination_cityMumbai classEconomy classBusiness
## 1                  -0.204                  0.092        0.825        -0.825
## 2                   0.255                 -0.135        0.438        -0.438
## 3                  -0.074                  0.036        0.141        -0.141
## 4                   0.036                 -0.031        0.947        -0.947
## 5                  -0.195                  0.118       -1.348         1.348
## 6                   0.126                 -0.047       -1.203         1.203
##   Cluster
## 1       1
## 2       2
## 3       3
## 4       4
## 5       5
## 6       6

# Hierarchical

# Sample and prepare features
# Same seed and sample size as the other clustering sections.
set.seed(42)
df_sampled <- sample_n(Clean_Dataset1, 30000)
continuous_features <- c("duration_scaled", "days_left_scaled", "price_log")
categorical_features <- c(
  "airline", "source_city", "departure_time", "stops",
  "arrival_time", "destination_city", "class"
)

# One-hot encoding with explicit class encoding
encoded_categorical <- model.matrix(
  ~ . - 1,
  data = df_sampled[categorical_features]
)
encoded_categorical_df <- as.data.frame(encoded_categorical)
# Explicit 0/1 indicators for both seat classes
encoded_categorical_df$classBusiness <- as.numeric(df_sampled$class == "Business")
encoded_categorical_df$classEconomy <- as.numeric(df_sampled$class == "Economy")
# Drop any column whose name starts with "classclass"
encoded_categorical_df <- encoded_categorical_df[
  , !grepl("^classclass", names(encoded_categorical_df)), drop = FALSE
]
data_encoded <- cbind(df_sampled[continuous_features], encoded_categorical_df)

# Impute and scale
# Median-impute missing values, then centre and scale each column.
preProcess_missingdata <- preProcess(data_encoded, method = "medianImpute")
data_encoded_imputed <- predict(preProcess_missingdata, data_encoded)
preProcess_scale <- preProcess(
  data_encoded_imputed,
  method = c("center", "scale")
)
X_scaled <- predict(preProcess_scale, data_encoded_imputed)

# Perform PCA
# Keep the first two components as the clustering space.
pca_result <- prcomp(X_scaled, center = TRUE, scale. = TRUE)
X_pca <- pca_result$x[, c("PC1", "PC2")]
X_pca_df <- as.data.frame(X_pca)

# Hierarchical Clustering
# Ward linkage on Euclidean distances in PCA space; the tree is cut at 6 groups.
n_clusters <- 6
dist_matrix <- dist(X_pca)
hclust_result <- hclust(dist_matrix, method = "ward.D2")
hclust_clusters <- cutree(hclust_result, k = n_clusters)

# Calculate silhouette score
# Mean silhouette width, reusing the distance matrix computed for hclust().
silhouette_hclust <- mean(
  silhouette(hclust_clusters, dist_matrix)[, "sil_width"]
)
cat(sprintf("\nHierarchical Clustering Silhouette Score: %.4f\n", silhouette_hclust))
## 
## Hierarchical Clustering Silhouette Score: 0.4217
# Create scatter plot
# PC1 vs PC2 for every sampled point, coloured by its hierarchical cluster.
plot_hclust <- ggplot(data = X_pca_df, mapping = aes(x = PC1, y = PC2)) +
  geom_point(
    mapping = aes(color = factor(hclust_clusters)),
    size = 2,
    alpha = 0.6
  ) +
  scale_color_viridis_d() +
  labs(
    x = "PCA Component 1",
    y = "PCA Component 2",
    color = "Cluster",
    title = "Hierarchical Clustering (6 clusters)"
  ) +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 14))
print(plot_hclust)

# Print cluster sizes
# Tabulate how many sampled rows fall in each of the cutree() groups.
cat("\nHierarchical cluster sizes:\n")
## 
## Hierarchical cluster sizes:
print(table(hclust_clusters))
## hclust_clusters
##    1    2    3    4    5    6 
## 5483 6363 5216 5057 4245 3636
# Analyze cluster characteristics
# Per-cluster mean of each continuous feature, rounded to 3 decimals. Labels
# are attached positionally, so hclust_clusters must follow df_sampled's row
# order.
cluster_stats <- df_sampled[continuous_features] %>%
 mutate(cluster = hclust_clusters) %>%
 group_by(cluster) %>%
 summarise(across(everything(), mean)) %>%
 round(3)
cat("\nHierarchical Cluster Characteristics:\n")
## 
## Hierarchical Cluster Characteristics:
print(cluster_stats)
## # A tibble: 6 × 4
##   cluster duration_scaled days_left_scaled price_log
##     <dbl>           <dbl>            <dbl>     <dbl>
## 1       1          -0.652            0.558      8.38
## 2       2          -0.017            0.503      8.91
## 3       3           0.987            0.507      8.8 
## 4       4          -0.348            0.527     10.8 
## 5       5           0.896            0.511     10.9 
## 6       6          -1.01             0.529      8.37
# Calculate centroids for hierarchical clustering in PCA space
# hclust() has no fitted centers, so each centroid is the mean PC1/PC2 of the
# points assigned to that cluster.
centroids_pca <- data.frame(
  Cluster = seq_len(n_clusters),
  PC1 = tapply(X_pca[, 1], hclust_clusters, mean),
  PC2 = tapply(X_pca[, 2], hclust_clusters, mean)
)

# Transform centroids back to original feature space
# Project through the two retained loadings, then undo prcomp()'s own
# standardisation: x = z * scale + center.
# (The previous scale(x, center = -center, scale = 1 / scale) call evaluated
# (z + center) * scale, which only matches when center is 0 or scale is 1.)
# NOTE(review): this recovers the coordinates of the matrix passed to prcomp()
# (X_scaled, i.e. already standardised features), not raw feature units.
rotation <- pca_result$rotation[, 1:2]
center <- pca_result$center
scale <- pca_result$scale
centroids_matrix <- matrix(c(centroids_pca$PC1, centroids_pca$PC2), ncol = 2)
centroids_original <- centroids_matrix %*% t(rotation)
# prcomp() stores FALSE instead of a vector when a step was skipped
if (is.numeric(scale)) {
  centroids_original <- sweep(centroids_original, 2, scale, "*")
}
if (is.numeric(center)) {
  centroids_original <- sweep(centroids_original, 2, center, "+")
}
centroids_original_df <- as.data.frame(centroids_original)
colnames(centroids_original_df) <- rownames(rotation)
centroids_original_df$Cluster <- seq_len(n_clusters)

# Print centroids
# First the per-cluster mean position in PC1/PC2 coordinates (3 decimals).
cat("\nCluster Centroids (PCA Space):\n")
## 
## Cluster Centroids (PCA Space):
print(round(centroids_pca, 3))
##   Cluster    PC1    PC2
## 1       1 -1.984 -0.105
## 2       2 -0.454  0.232
## 3       3 -0.350 -1.966
## 4       4  2.302  1.429
## 5       5  2.788 -0.722
## 6       6 -2.168  1.428
# Then the same centroids mapped back onto the encoded feature columns.
# NOTE(review): the PCA input was the centred/scaled matrix X_scaled, so these
# values look like standardised scores rather than raw units - confirm before
# interpreting magnitudes.
cat("\nCluster Centroids (Original Feature Space):\n")
## 
## Cluster Centroids (Original Feature Space):
print(round(centroids_original_df, 3))
##   duration_scaled days_left_scaled price_log airlineAir_India airlineAirAsia
## 1          -0.420            0.091    -0.994           -0.170          0.299
## 2          -0.216            0.023    -0.200           -0.122          0.044
## 3           0.803           -0.001    -0.383            0.606          0.240
## 4          -0.117           -0.094     1.293           -0.230         -0.473
## 5           0.971           -0.136     1.303            0.523         -0.336
## 6          -1.157            0.114    -0.921           -0.690          0.178
##   airlineGO_FIRST airlineIndigo airlineSpiceJet airlineVistara
## 1           0.294         0.477           0.154         -0.534
## 2           0.069         0.167           0.015         -0.071
## 3           0.036        -0.356           0.178         -0.479
## 4          -0.330        -0.258          -0.280          0.877
## 5          -0.420        -0.866          -0.149          0.578
## 6           0.333         0.869           0.049         -0.278
##   source_cityChennai source_cityDelhi source_cityHyderabad source_cityKolkata
## 1             -0.038            0.065               -0.010             -0.003
## 2              0.024           -0.013                0.004              0.002
## 3             -0.255            0.226               -0.051             -0.021
## 4              0.211           -0.220                0.045              0.018
## 5             -0.058            0.004               -0.007             -0.004
## 6              0.155           -0.099                0.028              0.013
##   source_cityMumbai departure_timeMorning departure_timeAfternoon
## 1            -0.015                -0.044                   0.111
## 2            -0.013                 0.058                   0.055
## 3             0.074                -0.523                  -0.205
## 4            -0.034                 0.397                   0.022
## 5             0.055                -0.168                  -0.256
## 6            -0.077                 0.360                   0.299
##   departure_timeEvening departure_timeLate_Night departure_timeNight
## 1                -0.008                    0.083              -0.098
## 2                -0.061                    0.021              -0.063
## 3                 0.450                    0.002               0.292
## 4                -0.294                   -0.088              -0.094
## 5                 0.212                   -0.122               0.276
## 6                -0.365                    0.101              -0.353
##   stopstwo_or_more stopszero arrival_timeMorning arrival_timeAfternoon
## 1            0.083     0.276              -0.097                 0.064
## 2           -0.015     0.127              -0.107                 0.030
## 3            0.271    -0.442               0.627                -0.108
## 4           -0.269     0.010              -0.320                 0.006
## 5           -0.002    -0.607               0.424                -0.143
## 6           -0.113     0.690              -0.616                 0.164
##   arrival_timeEvening arrival_timeLate_Night arrival_timeNight
## 1              -0.093                  0.220             0.002
## 2               0.003                  0.048             0.049
## 3              -0.204                  0.056            -0.371
## 4               0.234                 -0.267             0.247
## 5               0.047                 -0.302            -0.168
## 6               0.047                  0.227             0.296
##   destination_cityChennai destination_cityDelhi destination_cityHyderabad
## 1                  -0.038                 0.068                    -0.030
## 2                  -0.041                 0.065                    -0.017
## 3                   0.242                -0.364                     0.069
## 4                  -0.123                 0.173                    -0.015
## 5                   0.164                -0.263                     0.075
## 6                  -0.238                 0.372                    -0.092
##   destination_cityKolkata destination_cityMumbai classEconomy classBusiness
## 1                  -0.011                 -0.008        0.959        -0.959
## 2                  -0.039                  0.017        0.177        -0.177
## 3                   0.280                 -0.148        0.493        -0.493
## 4                  -0.177                  0.108       -1.330         1.330
## 5                   0.141                 -0.054       -1.203         1.203
## 6                  -0.235                  0.108        0.792        -0.792
##   Cluster
## 1       1
## 2       2
## 3       3
## 4       4
## 5       5
## 6       6

# DBSCAN

# Sample and prepare features
# Same seed and sample size as the K-means and hierarchical sections.
set.seed(42)
df_sampled <- sample_n(Clean_Dataset1, 30000)
continuous_features <- c("duration_scaled", "days_left_scaled", "price_log")
categorical_features <- c(
  "airline", "source_city", "departure_time", "stops",
  "arrival_time", "destination_city", "class"
)

# One-hot encoding with explicit class encoding
encoded_categorical <- model.matrix(
  ~ . - 1,
  data = df_sampled[categorical_features]
)
encoded_categorical_df <- as.data.frame(encoded_categorical)
# Explicit 0/1 indicators for both seat classes
encoded_categorical_df$classBusiness <- as.numeric(df_sampled$class == "Business")
encoded_categorical_df$classEconomy <- as.numeric(df_sampled$class == "Economy")
# Drop any column whose name starts with "classclass"
encoded_categorical_df <- encoded_categorical_df[
  , !grepl("^classclass", names(encoded_categorical_df)), drop = FALSE
]
data_encoded <- cbind(df_sampled[continuous_features], encoded_categorical_df)

# Impute and scale
# Median-impute missing values, then centre and scale each column.
preProcess_missingdata <- preProcess(data_encoded, method = "medianImpute")
data_encoded_imputed <- predict(preProcess_missingdata, data_encoded)
preProcess_scale <- preProcess(
  data_encoded_imputed,
  method = c("center", "scale")
)
X_scaled <- predict(preProcess_scale, data_encoded_imputed)

# Perform PCA
# The first two components form the space DBSCAN is run in.
pca_result <- prcomp(X_scaled, center = TRUE, scale. = TRUE)
X_pca <- pca_result$x[, c("PC1", "PC2")]
X_pca_df <- as.data.frame(X_pca)

# DBSCAN Clustering with adjusted parameters
# Scan eps from small to large and stop at the first value that yields at
# least `target_clusters` non-noise clusters (DBSCAN labels noise as 0).
# While scanning, remember the eps that produced the most clusters so far, so
# there is still a usable setting when the target is never reached.
eps_values <- seq(0.2, 2, by = 0.1)
best_eps <- NULL
best_n_clusters <- 0
target_clusters <- 6

for (eps in eps_values) {
  dbscan_test <- dbscan(X_pca, eps = eps, minPts = 5)
  # Local counter: the global n_clusters used by other sections is left alone
  n_found <- length(unique(dbscan_test$cluster[dbscan_test$cluster != 0]))
  if (n_found > best_n_clusters) {
    best_eps <- eps
    best_n_clusters <- n_found
  }
  if (n_found >= target_clusters) {
    break
  }
}

# Fail loudly instead of passing eps = NULL to dbscan()
if (is.null(best_eps)) {
  stop("DBSCAN found no clusters for any eps in the search grid", call. = FALSE)
}
if (best_n_clusters < target_clusters) {
  warning(
    sprintf(
      "No eps produced %d clusters; using eps = %.2f with %d cluster(s)",
      target_clusters, best_eps, best_n_clusters
    ),
    call. = FALSE
  )
}

# Run DBSCAN with optimized parameters
dbscan_result <- dbscan(X_pca, eps = best_eps, minPts = 5)
dbscan_clusters <- dbscan_result$cluster

# Calculate silhouette score (excluding noise points)
# Noise points (label 0) are dropped from both the distances and the labels.
non_noise_points <- which(dbscan_clusters != 0)
dist_matrix <- dist(X_pca[non_noise_points, , drop = FALSE])
clusters_no_noise <- dbscan_clusters[non_noise_points]
# A silhouette needs at least two clusters; otherwise report NA
if (length(unique(clusters_no_noise)) >= 2) {
  silhouette_dbscan <- mean(
    silhouette(clusters_no_noise, dist_matrix)[, "sil_width"]
  )
} else {
  silhouette_dbscan <- NA_real_
}
cat(sprintf("\nDBSCAN Silhouette Score: %.4f\n", silhouette_dbscan))
## 
## DBSCAN Silhouette Score: 0.1822
# Create plot
# PC1 vs PC2 coloured by DBSCAN label; the title reports the eps that was
# selected by the search.
plot_dbscan <- ggplot(data = X_pca_df, mapping = aes(x = PC1, y = PC2)) +
  geom_point(
    mapping = aes(color = factor(dbscan_clusters)),
    size = 2,
    alpha = 0.6
  ) +
  scale_color_viridis_d() +
  labs(
    x = "PCA Component 1",
    y = "PCA Component 2",
    color = "Cluster",
    title = sprintf("DBSCAN Clustering (eps = %.2f, minPts = 5)", best_eps)
  ) +
  theme_minimal() +
  theme(plot.title = element_text(face = "bold", size = 14))
print(plot_dbscan)

# Print cluster sizes
# Tabulate the DBSCAN labels, including the noise label 0.
cat("\nDBSCAN cluster sizes (0 represents noise points):\n")
## 
## DBSCAN cluster sizes (0 represents noise points):
print(table(dbscan_clusters))
## dbscan_clusters
##     0     1     2     3     4     5     6 
##    46 20653  9275     5     7     9     5
# Analyze cluster characteristics (excluding noise points)
# Rows are restricted to non_noise_points first so they line up one-to-one
# with clusters_no_noise; means are then taken per cluster and rounded to 3
# decimals.
cluster_stats <- df_sampled[continuous_features][non_noise_points,] %>%
 mutate(cluster = clusters_no_noise) %>%
 group_by(cluster) %>%
 summarise(across(everything(), mean)) %>%
 round(3)
cat("\nDBSCAN Cluster Characteristics (excluding noise points):\n")
## 
## DBSCAN Cluster Characteristics (excluding noise points):
print(cluster_stats)
## # A tibble: 6 × 4
##   cluster duration_scaled days_left_scaled price_log
##     <dbl>           <dbl>            <dbl>     <dbl>
## 1       1          -0.105            0.523      8.65
## 2       2           0.223            0.52      10.8 
## 3       3          -1.05             0.304     10.8 
## 4       4          -1.33             0.533     10.2 
## 5       5           0.59             0.63       7.64
## 6       6          -1.35             0.696      8.41

# Add cluster assignments to encoded data
# Labels are attached by position, so they must come from a clustering of the
# same rows that data_encoded_imputed holds; check the lengths before use.
stopifnot(length(kmeans_clusters) == nrow(data_encoded_imputed))
data_encoded_imputed$Cluster <- kmeans_clusters  # or hclust_clusters or dbscan_clusters

# Get continuous feature means by cluster
cluster_means <- data_encoded_imputed %>%
  group_by(Cluster) %>%
  summarise(across(all_of(continuous_features), mean))

# Identify categorical encoded columns
categorical_encoded_columns <- colnames(encoded_categorical_df)

# Get categorical feature means by cluster
# For 0/1 dummy columns the mean is the share of rows in the cluster that have
# that level (despite the variable name these are means, not modes).
cluster_modes <- data_encoded_imputed %>%
  group_by(Cluster) %>%
  summarise(across(all_of(categorical_encoded_columns), mean))

# Combine means and modes
# Cluster is dropped from the first table only, so it appears exactly once.
cluster_summary <- bind_cols(
  cluster_means %>% select(-Cluster),
  cluster_modes
)

# Convert to matrix for heatmap
# The Cluster id is a label, not a feature: keep it out of the plotted matrix
# and use it for the row names instead.
cluster_summary_matrix <- as.matrix(select(cluster_summary, -Cluster))
rownames(cluster_summary_matrix) <- paste("Cluster", cluster_summary$Cluster)

# Transposed so that features are rows and clusters are columns
pheatmap(t(cluster_summary_matrix),
         main = "Cluster Feature Summary",
         display_numbers = TRUE,
         number_format = "%.2f",
         fontsize = 8,
         cluster_rows = FALSE,
         cluster_cols = FALSE,
         color = colorRampPalette(c("pink", "yellow", "lightblue"))(100))

# Print the summary table
# cluster_summary holds one row per cluster: continuous-feature means, the
# Cluster id, then the per-cluster means of the encoded categorical columns.
print("Cluster Summary Statistics:")
## [1] "Cluster Summary Statistics:"
print(cluster_summary)
## # A tibble: 6 × 34
##   duration_scaled days_left_scaled price_log Cluster airlineAir_India
##             <dbl>            <dbl>     <dbl>   <int>            <dbl>
## 1          -0.979            0.535      8.37       1           0.0185
## 2           0.962            0.502      8.83       2           0.504 
## 3          -0.155            0.504      8.93       3           0.177 
## 4          -0.520            0.558      8.41       4           0.154 
## 5          -0.396            0.527     10.8        5           0.208 
## 6           0.822            0.512     10.9        6           0.497 
## # ℹ 29 more variables: airlineAirAsia <dbl>, airlineGO_FIRST <dbl>,
## #   airlineIndigo <dbl>, airlineSpiceJet <dbl>, airlineVistara <dbl>,
## #   source_cityChennai <dbl>, source_cityDelhi <dbl>,
## #   source_cityHyderabad <dbl>, source_cityKolkata <dbl>,
## #   source_cityMumbai <dbl>, departure_timeMorning <dbl>,
## #   departure_timeAfternoon <dbl>, departure_timeEvening <dbl>,
## #   departure_timeLate_Night <dbl>, departure_timeNight <dbl>, …

Modelling

# Linear Regression

# Feature preparation
continuous_features <- c("duration_scaled", "days_left_scaled", "price_log")
categorical_features <- c(
  "airline", "source_city", "departure_time", "stops",
  "arrival_time", "destination_city", "class"
)

# Enhanced encoding with explicit class features
# Unlike the clustering sections, this encodes the full Clean_Dataset1 rather
# than a sample.
encoded_categorical <- model.matrix(
  ~ . - 1,
  data = Clean_Dataset1[categorical_features]
)
encoded_categorical_df <- as.data.frame(encoded_categorical)
encoded_categorical_df$classBusiness <- as.numeric(Clean_Dataset1$class == "Business")
encoded_categorical_df$classEconomy <- as.numeric(Clean_Dataset1$class == "Economy")
# Drop any column whose name starts with "classclass"
encoded_categorical_df <- encoded_categorical_df[
  , !grepl("^classclass", names(encoded_categorical_df)), drop = FALSE
]
data_encoded <- cbind(Clean_Dataset1[continuous_features], encoded_categorical_df)

# Impute missing values
preProcess_missingdata <- preProcess(data_encoded, method = "medianImpute")
data_encoded_imputed <- predict(preProcess_missingdata, data_encoded)

# Standardize the data
preProcess_scale <- preProcess(
  data_encoded_imputed,
  method = c("center", "scale")
)
X_scaled <- predict(preProcess_scale, data_encoded_imputed)

# Apply PCA
# The first two components are the space the cluster labels are defined in.
pca_result <- prcomp(X_scaled, center = TRUE, scale. = TRUE)
X_pca <- pca_result$x[, c("PC1", "PC2")]

# Perform K-Means Clustering
# These labels are the targets that the supervised models below try to recover.
n_clusters <- 6
set.seed(42)
kmeans_result <- kmeans(
  X_pca,
  centers = n_clusters,
  iter.max = 100,
  nstart = 25
)
clusters <- kmeans_result[["cluster"]]

# Sample for regression
# Reproducible 30,000-row subsample of the clustered data.
set.seed(42)
sample_size <- 30000
sample_indices <- sample(length(clusters), sample_size)

X_scaled_sampled <- X_scaled[sample_indices, ]
clusters_sampled <- clusters[sample_indices]

# Create cluster dummies
# One 0/1 indicator column per cluster label; column i corresponds to the i-th
# sorted label, which equals cluster i when every label 1..k is in the sample.
cluster_dummies <- model.matrix(~ factor(clusters_sampled) - 1)
colnames(cluster_dummies) <- paste0("cluster_", seq_len(ncol(cluster_dummies)))
cluster_columns <- colnames(cluster_dummies)

# Prepare regression data
regression_data <- data.frame(X_scaled_sampled, cluster_dummies)

# Train-test split
# NOTE(review): clusters_sampled is numeric, so createDataPartition() appears to
# stratify on quantile bins rather than per cluster; wrapping it in factor()
# would stratify per cluster but changes the split and the figures below.
set.seed(42)
train_index <- createDataPartition(clusters_sampled, p = 0.7, list = FALSE)
train_data <- regression_data[train_index, ]
test_data <- regression_data[-train_index, ]

# Train models and make predictions
# One-vs-rest: one linear model per cluster indicator, using every feature
# column and excluding the other clusters' indicator columns.
models <- vector("list", length(cluster_columns))
predictions <- matrix(0, nrow = nrow(test_data), ncol = length(cluster_columns))

for (i in seq_along(cluster_columns)) {
  other_clusters <- setdiff(cluster_columns, cluster_columns[i])
  model_formula <- as.formula(
    paste(cluster_columns[i], "~ . -", paste(other_clusters, collapse = " - "))
  )

  models[[i]] <- lm(model_formula, data = train_data)
  predictions[, i] <- predict(models[[i]], newdata = test_data)
}

# Get predicted clusters
# Pick the cluster whose model gives the highest score. max.col() breaks
# near-ties at random by default; "first" keeps the result deterministic.
predicted_clusters <- max.col(predictions, ties.method = "first")
actual_clusters <- clusters_sampled[-train_index]

# Calculate accuracy
accuracy <- mean(predicted_clusters == actual_clusters)
cat("\nAccuracy:", round(accuracy, 4), "\n")
## 
## Accuracy: 0.8537
# Create confusion matrix
# Give both factors the same level set so confusionMatrix() still works when a
# cluster is never predicted.
cluster_levels <- seq_len(ncol(predictions))
confusion_mat <- confusionMatrix(
  factor(predicted_clusters, levels = cluster_levels),
  factor(actual_clusters, levels = cluster_levels)
)
print(confusion_mat)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    1    2    3    4    5    6
##          1 1203    0  271   17    0    0
##          2    0 1310    0    2    0   64
##          3  187    0  885   23   96    0
##          4   99    0  129 1339   86    0
##          5    0    0   78  123 1652    0
##          6    0  133    0    8    0 1293
## 
## Overall Statistics
##                                                
##                Accuracy : 0.8537               
##                  95% CI : (0.8463, 0.861)      
##     No Information Rate : 0.2038               
##     P-Value [Acc > NIR] : < 0.00000000000000022
##                                                
##                   Kappa : 0.824                
##                                                
##  Mcnemar's Test P-Value : NA                   
## 
## Statistics by Class:
## 
##                      Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6
## Sensitivity            0.8079   0.9078  0.64930   0.8856   0.9008   0.9528
## Specificity            0.9616   0.9913  0.95992   0.9581   0.9719   0.9815
## Pos Pred Value         0.8068   0.9520  0.74307   0.8100   0.8915   0.9017
## Neg Pred Value         0.9619   0.9826  0.93877   0.9764   0.9745   0.9915
## Prevalence             0.1655   0.1604  0.15148   0.1680   0.2038   0.1508
## Detection Rate         0.1337   0.1456  0.09836   0.1488   0.1836   0.1437
## Detection Prevalence   0.1657   0.1529  0.13236   0.1837   0.2059   0.1594
## Balanced Accuracy      0.8848   0.9495  0.80461   0.9218   0.9364   0.9672
# Visualize results
# Test-set points in PCA space with their predicted and actual cluster labels.
# Both factors share one level set so a given cluster gets the same colour in
# both panels even if some cluster is never predicted.
plot_cluster_levels <- sort(unique(c(predicted_clusters, actual_clusters)))
plot_data <- data.frame(
  PC1 = X_pca[sample_indices, 1][-train_index],
  PC2 = X_pca[sample_indices, 2][-train_index],
  Predicted = factor(predicted_clusters, levels = plot_cluster_levels),
  Actual = factor(actual_clusters, levels = plot_cluster_levels)
)

# Create plots
# Left panel: clusters predicted by the one-vs-rest linear models.
prediction_plot <- ggplot(plot_data, aes(x = PC1, y = PC2)) +
  geom_point(aes(color = Predicted), alpha = 0.6, size = 3) +
  scale_color_viridis_d(drop = FALSE) +
  labs(title = "Linear Regression Predicted Clusters",
       x = "PCA Component 1", y = "PCA Component 2",
       color = "Predicted Cluster") +
  theme_minimal() +
  theme(plot.title = element_text(size = 14, face = "bold"))

# Right panel: the K-means labels the models were trained to reproduce.
actual_plot <- ggplot(plot_data, aes(x = PC1, y = PC2)) +
  geom_point(aes(color = Actual), alpha = 0.6, size = 3) +
  scale_color_viridis_d(drop = FALSE) +
  labs(title = "Actual Clusters",
       x = "PCA Component 1", y = "PCA Component 2",
       color = "Actual Cluster") +
  theme_minimal() +
  theme(plot.title = element_text(size = 14, face = "bold"))

grid.arrange(prediction_plot, actual_plot, ncol = 2)

# Model analysis
# Despite the "Summary" wording, only the R-squared of each model is printed.
cat("\nModel Summaries for First 3 Clusters:\n")
## 
## Model Summaries for First 3 Clusters:
# Guarded so the loop cannot index past the end of `models`
for (i in seq_len(min(3, length(models)))) {
  cat(sprintf("\nCluster %d Summary:\n", i))
  print(summary(models[[i]])$r.squared)
}
## 
## Cluster 1 Summary:
## [1] 0.6537257
## 
## Cluster 2 Summary:
## [1] 0.5816675
## 
## Cluster 3 Summary:
## [1] 0.4439028
# vapply() guarantees a plain numeric vector, one R-squared per model
r_squared_values <- vapply(
  models,
  function(m) summary(m)$r.squared,
  numeric(1)
)
cat("\nR-squared values for all models:\n")
## 
## R-squared values for all models:
print(round(r_squared_values, 4))
## [1] 0.6537 0.5817 0.4439 0.5034 0.6005 0.5674

# K-NN

# Sample and prepare
# Reproducible 30,000-row subsample. The features (X_scaled), the PCA
# coordinates (X_pca) and the labels (clusters) come from the full-data
# pipeline earlier in the Modelling section, so the categorical columns do not
# need to be re-encoded here; that recomputation was unused and was removed.
set.seed(42)
sample_size <- 30000
sample_indices <- sample(length(clusters), sample_size)

X_scaled_sampled <- X_scaled[sample_indices, ]
X_pca_sampled <- X_pca[sample_indices, ]
clusters_sampled <- clusters[sample_indices]

# Train-Test Split
# 70/30 split; the random draw continues from the seed set above.
train_index <- createDataPartition(clusters_sampled, p = 0.7, list = FALSE)
train_data <- data.frame(X_scaled_sampled[train_index, ])
train_data$Cluster <- factor(clusters_sampled[train_index])
test_data <- data.frame(X_scaled_sampled[-train_index, ])
test_data$Cluster <- factor(clusters_sampled[-train_index])

# Train KNN
# 5-fold cross-validation with k fixed at 5 (single-row tuning grid).
ctrl <- trainControl(method = "cv", number = 5)
knn_model <- train(Cluster ~ .,
                   data = train_data,
                   method = "knn",
                   trControl = ctrl,
                   tuneGrid = data.frame(k = 5))

# Score the held-out rows and compare with their K-means labels
knn_pred <- predict(knn_model, newdata = test_data)
confusion_mat <- confusionMatrix(knn_pred, test_data$Cluster)

# Report the K-NN confusion matrix computed on the held-out test rows.
cat("\nConfusion Matrix and Statistics:\n")
## 
## Confusion Matrix and Statistics:
print(confusion_mat)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction    1    2    3    4    5    6
##          1 1460    0   45   44    0    1
##          2    0 1361    1    0    0   48
##          3   67    0 1255   64   23    3
##          4   15    2   17 1381   43   11
##          5    0    2   27   36 1768    0
##          6    4   21    1    4    0 1294
## 
## Overall Statistics
##                                                
##                Accuracy : 0.9468               
##                  95% CI : (0.9419, 0.9513)     
##     No Information Rate : 0.2038               
##     P-Value [Acc > NIR] : < 0.00000000000000022
##                                                
##                   Kappa : 0.936                
##                                                
##  Mcnemar's Test P-Value : NA                   
## 
## Statistics by Class:
## 
##                      Class: 1 Class: 2 Class: 3 Class: 4 Class: 5 Class: 6
## Sensitivity            0.9444   0.9820   0.9324   0.9032   0.9640   0.9536
## Specificity            0.9879   0.9936   0.9795   0.9882   0.9909   0.9961
## Pos Pred Value         0.9419   0.9652   0.8888   0.9401   0.9645   0.9773
## Neg Pred Value         0.9885   0.9967   0.9880   0.9803   0.9908   0.9918
## Prevalence             0.1718   0.1540   0.1496   0.1699   0.2038   0.1508
## Detection Rate         0.1623   0.1513   0.1395   0.1535   0.1965   0.1438
## Detection Prevalence   0.1723   0.1567   0.1569   0.1633   0.2037   0.1471
## Balanced Accuracy      0.9661   0.9878   0.9559   0.9457   0.9775   0.9748
# Headline accuracy on the test split, taken from the `overall` summary vector.
cat("\nAccuracy Score:", confusion_mat[["overall"]][["Accuracy"]], "\n")
## 
## Accuracy Score: 0.9467659
# Visualization
# Test-set rows on the first two PCA columns, carrying both the K-NN
# prediction and the original cluster label.
plot_data <- data.frame(
  PC1 = X_pca_sampled[-train_index, 1],
  PC2 = X_pca_sampled[-train_index, 2],
  Predicted = knn_pred,
  Actual = factor(clusters_sampled[-train_index])
)

# Scaffolding shared by both panels: axis mapping, palette and theme.
pca_panel_base <- ggplot(plot_data, aes(x = PC1, y = PC2)) +
  scale_color_viridis_d() +
  theme_minimal() +
  theme(plot.title = element_text(size = 14, face = "bold"))

# Left panel: points coloured by the cluster K-NN predicted.
prediction_plot <- pca_panel_base +
  geom_point(aes(color = Predicted), alpha = 0.6, size = 3) +
  labs(
    title = "K-NN Predicted Clusters (Test Data)",
    x = "PCA Component 1",
    y = "PCA Component 2",
    color = "Predicted Cluster"
  )

# Right panel: the same points coloured by their original cluster label.
actual_plot <- pca_panel_base +
  geom_point(aes(color = Actual), alpha = 0.6, size = 3) +
  labs(
    title = "Actual Clusters (Test Data)",
    x = "PCA Component 1",
    y = "PCA Component 2",
    color = "Actual Cluster"
  )

grid.arrange(prediction_plot, actual_plot, ncol = 2)

# Per-cluster metric table (one row per class) from the confusion-matrix object.
cat("\nDetailed Metrics per Class:\n")
## 
## Detailed Metrics per Class:
print(confusion_mat[["byClass"]])
##          Sensitivity Specificity Pos Pred Value Neg Pred Value Precision
## Class: 1   0.9443726   0.9879227      0.9419355      0.9884533 0.9419355
## Class: 2   0.9819625   0.9935628      0.9652482      0.9967053 0.9652482
## Class: 3   0.9323923   0.9794825      0.8888102      0.9880042 0.8888102
## Class: 4   0.9032047   0.9882180      0.9400953      0.9803427 0.9400953
## Class: 5   0.9640131   0.9909269      0.9645390      0.9907886 0.9645390
## Class: 6   0.9535741   0.9960738      0.9773414      0.9917905 0.9773414
##             Recall        F1 Prevalence Detection Rate Detection Prevalence
## Class: 1 0.9443726 0.9431525  0.1718160      0.1622583            0.1722605
## Class: 2 0.9819625 0.9735336  0.1540342      0.1512558            0.1567015
## Class: 3 0.9323923 0.9100798  0.1495888      0.1394754            0.1569238
## Class: 4 0.9032047 0.9212809  0.1699267      0.1534786            0.1632585
## Class: 5 0.9640131 0.9642760  0.2038231      0.1964881            0.2037119
## Class: 6 0.9535741 0.9653115  0.1508113      0.1438097            0.1471438
##          Balanced Accuracy
## Class: 1         0.9661476
## Class: 2         0.9877626
## Class: 3         0.9559374
## Class: 4         0.9457113
## Class: 5         0.9774700
## Class: 6         0.9748239
# Overall summary vector (accuracy, kappa, accuracy bounds, p-values).
cat("\nOverall Statistics:\n")
## 
## Overall Statistics:
print(confusion_mat[["overall"]])
##       Accuracy          Kappa  AccuracyLower  AccuracyUpper   AccuracyNull 
##      0.9467659      0.9359658      0.9419251      0.9513149      0.2038231 
## AccuracyPValue  McnemarPValue 
##      0.0000000            NaN