#Introduction
In the era of rapid urbanization, transportation services play a crucial role in facilitating movement within cities. Predicting the cost of taxi rides is an essential task for both passengers and service providers. For passengers, it allows for better budgeting and informed decision-making when selecting transportation options. For taxi service providers, accurate fare predictions can improve pricing strategies, enhance customer satisfaction, and optimize operational efficiency.
This project aims to tackle the problem of taxi fare prediction using machine learning techniques. The goal is to predict the fare amount based on input features such as the pickup and drop-off locations, the time of day, and the number of passengers.
##Significance of the Problem
Accurate taxi fare prediction is important for:
Passengers: Provides transparency in pricing and helps with planning and budgeting for travel expenses.
Service Providers: Enables the development of dynamic pricing models, improves customer experience through fare estimation features, and helps in fleet optimization by forecasting demand and planning routes.
##Challenges
The dataset introduces several challenges that align with real-world data complexities:
Noise in Data: Predictors include irrelevant or weakly correlated features, requiring effective feature selection methods.
Data Transformations: Values have been rescaled and shifted, necessitating careful preprocessing and exploratory data analysis.
Synthetic Predictors: The dataset includes engineered features whose relevance must be assessed during the analysis.
Real-World Irregularities: Outliers, missing values, and inconsistencies in geospatial and temporal data must be addressed.
##Dataset Description
The dataset contains the following columns:
id: Unique identifier for each ride record.
dropoff_latitude & dropoff_longitude: The geographic coordinates where the taxi ride ended.
fare_amount: The target variable representing the cost of the taxi ride in dollars.
feat01 - feat10: Synthetic features added to the dataset.
key: A unique string identifier for each record, combining the pickup datetime and a unique integer.
passenger_count: The number of passengers in the taxi during the ride.
pickup_datetime: The timestamp indicating when the taxi ride started.
pickup_latitude & pickup_longitude: The geographic coordinates where the taxi ride began.
#Loading packages
#Install libraries
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(tidyr)
library(data.table)
##
## Attaching package: 'data.table'
## The following objects are masked from 'package:dplyr':
##
## between, first, last
library(ggplot2)
library(cowplot)
library(gridExtra)
##
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
##
## combine
library(corrplot)
## corrplot 0.95 loaded
library(scales)
library(caret)
## Loading required package: lattice
library(randomForest)
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:gridExtra':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
## The following object is masked from 'package:dplyr':
##
## combine
library(rpart)
library(xgboost)
##
## Attaching package: 'xgboost'
## The following object is masked from 'package:dplyr':
##
## slice
library(gbm)
## Loaded gbm 2.2.2
## This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
library(e1071)
library(Metrics)
##
## Attaching package: 'Metrics'
## The following objects are masked from 'package:caret':
##
## precision, recall
library(MLmetrics)
##
## Attaching package: 'MLmetrics'
## The following objects are masked from 'package:caret':
##
## MAE, RMSE
## The following object is masked from 'package:base':
##
## Recall
library(ROCR)
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following object is masked from 'package:Metrics':
##
## auc
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
library(nnet)
library(keras)
library(tensorflow)
##
## Attaching package: 'tensorflow'
## The following object is masked from 'package:caret':
##
## train
library(knitr)
library(rmarkdown)
library(kableExtra)
##
## Attaching package: 'kableExtra'
## The following object is masked from 'package:dplyr':
##
## group_rows
library(here)
## here() starts at /Users/paidamoyo/MACHINE OUTLIERS
library(rpart.plot)
#Loading dataset
taxi_data <- read.csv("~/Desktop/machine learning/r3.csv")
str(taxi_data)
## 'data.frame': 90000 obs. of 19 variables:
## $ id : int 1 2 3 4 5 6 7 8 9 10 ...
## $ dropoff_latitude : num 40.8 40.8 40.8 40.8 40.7 ...
## $ dropoff_longitude: num -74 -74 -74 -74 -74 ...
## $ fare_amount : num 16.6 23.2 18.2 23.2 18.2 ...
## $ feat01 : num 0.59368 0.00444 0.55456 0.69443 0.93035 ...
## $ feat02 : num 0.277 0.408 0.542 0.448 0.444 ...
## $ feat03 : num 0.78 0.413 0.213 0.101 0.981 ...
## $ feat04 : num 0.207 0.526 0.826 0.207 0.99 ...
## $ feat05 : num 0.0118 0.6049 0.0816 0.2521 0.4238 ...
## $ feat06 : num 0.381 0.326 0.584 0.495 0.725 ...
## $ feat07 : num 0.7739 0.0959 0.6505 0.9237 0.2969 ...
## $ feat08 : num 0.5504 0.0215 0.8157 0.6024 0.2126 ...
## $ feat09 : num 0.352 0.396 0.33 0.296 0.264 ...
## $ feat10 : num 0.226 0.489 0.703 0.226 0.229 ...
## $ key : chr "2014-07-22 12:28:44.0000002" "2015-02-14 14:14:50.0000005" "2015-04-30 07:19:42.0000001" "2014-03-23 02:10:00.000000176" ...
## $ passenger_count : int 1 1 1 1 2 1 1 1 2 1 ...
## $ pickup_datetime : chr "2014-07-22 12:28:44 UTC" "2015-02-14 14:14:50 UTC" "2015-04-30 07:19:42 UTC" "2014-03-23 02:10:00 UTC" ...
## $ pickup_latitude : num 40.8 40.8 40.8 40.8 40.8 ...
## $ pickup_longitude : num -74 -74 -74 -74 -74 ...
summary(taxi_data)
## id dropoff_latitude dropoff_longitude fare_amount
## Min. : 1 Min. : 0.00 Min. :-171.80 Min. : 4.50
## 1st Qu.:22501 1st Qu.:40.73 1st Qu.: -73.99 1st Qu.: 17.15
## Median :45000 Median :40.75 Median : -73.98 Median : 20.45
## Mean :45000 Mean :39.99 Mean : -72.60 Mean : 24.22
## 3rd Qu.:67500 3rd Qu.:40.77 3rd Qu.: -73.96 3rd Qu.: 25.95
## Max. :90000 Max. :69.70 Max. : 0.00 Max. :408.56
## feat01 feat02 feat03 feat04
## Min. :0.0000347 Min. :0.0000 Min. :0.0000054 Min. :0.0000013
## 1st Qu.:0.2534391 1st Qu.:0.3474 1st Qu.:0.2512272 1st Qu.:0.2494841
## Median :0.5006765 Median :0.4196 Median :0.4983004 Median :0.4996805
## Mean :0.5012469 Mean :0.4209 Mean :0.4995880 Mean :0.4991737
## 3rd Qu.:0.7507510 3rd Qu.:0.4935 3rd Qu.:0.7495345 3rd Qu.:0.7479510
## Max. :0.9999880 Max. :1.0000 Max. :0.9999581 Max. :0.9999977
## feat05 feat06 feat07
## Min. :0.0000036 Min. :0.0000044 Min. :0.0000016
## 1st Qu.:0.2512522 1st Qu.:0.2476465 1st Qu.:0.2494061
## Median :0.5009581 Median :0.4961104 Median :0.4982960
## Mean :0.4999889 Mean :0.4978578 Mean :0.4996831
## 3rd Qu.:0.7484398 3rd Qu.:0.7481329 3rd Qu.:0.7517410
## Max. :0.9999558 Max. :0.9999946 Max. :0.9999630
## feat08 feat09 feat10 key
## Min. :0.0000107 Min. :0.0000 Min. :0.0000383 Length:90000
## 1st Qu.:0.2509321 1st Qu.:0.2987 1st Qu.:0.2506165 Class :character
## Median :0.5014939 Median :0.3570 Median :0.5021406 Mode :character
## Mean :0.5005764 Mean :0.3584 Mean :0.5011385
## 3rd Qu.:0.7510181 3rd Qu.:0.4165 3rd Qu.:0.7519337
## Max. :0.9999996 Max. :1.0000 Max. :0.9999946
## passenger_count pickup_datetime pickup_latitude pickup_longitude
## Min. :0.000 Length:90000 Min. : 0.00 Min. :-81.23
## 1st Qu.:1.000 Class :character 1st Qu.:40.74 1st Qu.:-73.99
## Median :1.000 Mode :character Median :40.75 Median :-73.98
## Mean :1.693 Mean :40.00 Mean :-72.61
## 3rd Qu.:2.000 3rd Qu.:40.77 3rd Qu.:-73.97
## Max. :6.000 Max. :41.40 Max. : 0.00
##Data Cleaning
# Converting pickup_datetime to datetime format
taxi_data$pickup_datetime <- as.POSIXct(taxi_data$pickup_datetime, format = "%Y-%m-%d %H:%M:%S", tz = "UTC")
colSums(is.na(taxi_data))
## id dropoff_latitude dropoff_longitude fare_amount
## 0 0 0 0
## feat01 feat02 feat03 feat04
## 0 0 0 0
## feat05 feat06 feat07 feat08
## 0 0 0 0
## feat09 feat10 key passenger_count
## 0 0 0 0
## pickup_datetime pickup_latitude pickup_longitude
## 0 0 0
##Visualising the Target Variable
The histogram shows that fare_amount has a highly right-skewed distribution, with the bulk of fares falling below roughly $26 (the third quartile in the summary above). This suggests that most taxi rides are relatively short trips. A small number of fares exceed $100, which could be outliers or represent longer, more expensive trips. The fare distribution aligns with typical taxi ride patterns, where shorter trips are more common and the frequency of fares decreases as the fare amount increases.
library(ggplot2)
ggplot(taxi_data, aes(x = fare_amount)) +
geom_histogram(bins = 30, fill = "blue", alpha = 0.7) +
theme_minimal() +
labs(title = "Distribution of Fare Amount", x = "Fare Amount", y = "Frequency")
###Pickup and Dropoff Locations
Pickup Locations: The scatter plot reveals that pickup locations are densely clustered around certain geographic regions, likely corresponding to urban centers or high-demand areas such as airports and city hubs. A few outliers exist far from the main clusters, which could indicate long-distance trips, data entry errors, or rides originating from remote locations.
Dropoff Locations: Similar to pickup points, dropoff locations are clustered in specific regions where rides often terminate. Outliers are also observed, indicating rides ending in unexpected or distant areas, which might require further investigation.
# pickup locations
ggplot(taxi_data, aes(x = pickup_longitude, y = pickup_latitude)) +
geom_point(alpha = 0.1, color = "blue") +
theme_minimal() +
labs(title = "Pickup Locations",
x = "Longitude",
y = "Latitude")
# dropoff locations
ggplot(taxi_data, aes(x = dropoff_longitude, y = dropoff_latitude)) +
geom_point(alpha = 0.1, color = "red") +
theme_minimal() +
labs(title = "Dropoff Locations",
x = "Longitude",
y = "Latitude")
##Outliers
remove_outliers <- function(data, cols) {
for (col in cols) {
Q1 <- quantile(data[[col]], 0.25, na.rm = TRUE) # 1st quartile
Q3 <- quantile(data[[col]], 0.75, na.rm = TRUE) # 3rd quartile
IQR <- Q3 - Q1 # Interquartile Range
# Filtering data within 1.5 * IQR from Q1 and Q3
data <- data[data[[col]] >= (Q1 - 1.5 * IQR) & data[[col]] <= (Q3 + 1.5 * IQR), ]
}
return(data)
}
numeric_cols <- sapply(taxi_data, is.numeric) # Identify numeric columns
taxi_data <- remove_outliers(taxi_data, cols = names(taxi_data)[numeric_cols])
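As a quick sanity check (a sketch added here, assuming the same CSV path used when loading the data), we can compare the row count before and after the IQR filter:
# Sketch: how many rows the IQR filter removed (re-reads the raw file only to count rows)
raw_rows <- nrow(read.csv("~/Desktop/machine learning/r3.csv"))
cat("Rows before filtering:", raw_rows, "\n")
cat("Rows after filtering:", nrow(taxi_data), "\n")
cat("Rows removed:", raw_rows - nrow(taxi_data), "\n")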
The dropoff and pickup points now show a clear clustering pattern, with most rides starting and terminating in specific urban areas.
The cleaned data better represents typical pickup and dropoff locations, making the analysis more reliable and the model less prone to overfitting.
# pickup locations
ggplot(taxi_data, aes(x = pickup_longitude, y = pickup_latitude)) +
geom_point(alpha = 0.1, color = "blue") +
theme_minimal() +
labs(title = "Pickup Locations",
x = "Longitude",
y = "Latitude")
# dropoff locations
ggplot(taxi_data, aes(x = dropoff_longitude, y = dropoff_latitude)) +
geom_point(alpha = 0.1, color = "red") +
theme_minimal() +
labs(title = "Dropoff Locations",
x = "Longitude",
y = "Latitude")
After removing outliers, the histogram of fare amounts reveals a more refined distribution. The fare values are primarily concentrated in the range of approximately $5 to $30, which aligns with typical fare amounts for regular taxi rides in urban settings. The removal of outliers has eliminated extremely high fare amounts, making the data less skewed and more representative of common fare distributions.
ggplot(taxi_data, aes(x = fare_amount)) +
geom_histogram(bins = 30, fill = "blue", alpha = 0.7) +
theme_minimal() +
labs(title = "Distribution of Fare Amount", x = "Fare Amount", y = "Frequency")
##Correlation
Key Features: feat02 and feat09 are the most correlated with fare_amount. Weak Predictors: Features such as passenger_count, feat06, feat01, and others show minimal correlation, suggesting limited predictive value. The latitude and longitude variables show only modest negative correlations (roughly -0.08 to -0.11), but they are retained because location is central to fare prediction.
numeric_features <- taxi_data[, sapply(taxi_data, is.numeric)]
# correlation matrix
cor_matrix <- cor(numeric_features, use = "complete.obs")
fare_correlations <- cor_matrix["fare_amount", ]
fare_correlations <- sort(fare_correlations, decreasing = TRUE)
# Displaying correlations
print(fare_correlations)
## fare_amount feat02 feat09 passenger_count
## 1.0000000000 0.2067938089 0.1919435413 0.0251295670
## feat06 feat01 feat05 feat04
## 0.0064393331 0.0053301239 0.0041920936 0.0002712407
## feat10 dropoff_longitude feat08 feat07
## -0.0008267462 -0.0028571376 -0.0032224575 -0.0033964407
## feat03 id pickup_longitude pickup_latitude
## -0.0051321869 -0.0108852506 -0.0823749095 -0.1023880823
## dropoff_latitude
## -0.1073983375
# Visualize the correlation matrix
corrplot(cor_matrix, method = "circle", type = "lower", tl.col = "black", tl.cex = 0.8)
##Standardization
#standardizing numeric columns
columns_to_standardize <- c("dropoff_latitude", "dropoff_longitude", "feat01", "feat02",
"feat03", "feat04", "feat05", "feat06", "feat07",
"feat08", "feat09", "feat10", "pickup_latitude",
"pickup_longitude")
taxi_data[columns_to_standardize] <- scale(taxi_data[columns_to_standardize])
str(taxi_data)
## 'data.frame': 64830 obs. of 19 variables:
## $ id : int 1 2 3 4 5 6 7 9 10 11 ...
## $ dropoff_latitude : num 0.146 1.089 0.278 1.681 -0.432 ...
## $ dropoff_longitude: num 0.3546 -0.2841 0.0269 0.7589 0.2378 ...
## $ fare_amount : num 16.6 23.2 18.2 23.2 18.2 ...
## $ feat01 : num 0.321 -1.726 0.185 0.671 1.49 ...
## $ feat02 : num -1.3029 -0.0161 1.3077 0.3785 0.3388 ...
## $ feat03 : num 0.974 -0.303 -0.995 -1.387 1.674 ...
## $ feat04 : num -1.0118 0.0953 1.1322 -1.0133 1.7017 ...
## $ feat05 : num -1.692 0.365 -1.45 -0.859 -0.263 ...
## $ feat06 : num -0.40148 -0.59295 0.2994 -0.00979 0.78623 ...
## $ feat07 : num 0.946 -1.398 0.52 1.465 -0.703 ...
## $ feat08 : num 0.174 -1.659 1.093 0.354 -0.996 ...
## $ feat09 : num 0.0456 0.5905 -0.222 -0.6389 -1.0274 ...
## $ feat10 : num -0.952 -0.0419 0.7006 -0.9537 -0.9423 ...
## $ key : chr "2014-07-22 12:28:44.0000002" "2015-02-14 14:14:50.0000005" "2015-04-30 07:19:42.0000001" "2014-03-23 02:10:00.000000176" ...
## $ passenger_count : int 1 1 1 1 2 1 1 2 1 1 ...
## $ pickup_datetime : POSIXct, format: "2014-07-22 12:28:44" "2015-02-14 14:14:50" ...
## $ pickup_latitude : num 0.60329 -0.08839 -0.09787 -0.00952 0.28573 ...
## $ pickup_longitude : num 0.861 -0.526 -0.548 -0.213 0.303 ...
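As a quick check (a sketch, assuming scale() was applied as above), the standardized columns should now have means close to 0 and standard deviations close to 1:
# Sketch: verify standardization (means ~ 0, standard deviations ~ 1)
round(colMeans(taxi_data[columns_to_standardize]), 3)
round(sapply(taxi_data[columns_to_standardize], sd), 3)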
#Feature Selection
Based on the correlation values, we retain the more highly correlated features and drop those with negligible correlation. The dropped columns are "feat07", "feat01", "feat04", "feat03", "feat08", "feat05", "id", and "feat10", all of which have almost no correlation with fare_amount.
#correlation matrix for numeric features
numeric_features <- taxi_data[, sapply(taxi_data, is.numeric)] # Select only numeric columns
cor_matrix <- cor(numeric_features, use = "complete.obs")
cor_target <- cor_matrix["fare_amount", ]
cor_sorted <- sort(cor_target, decreasing = TRUE)
#sorted correlation values
cat("Correlation values with fare_amount:\n")
## Correlation values with fare_amount:
print(cor_sorted)
## fare_amount feat02 feat09 passenger_count
## 1.0000000000 0.2067938089 0.1919435413 0.0251295670
## feat06 feat01 feat05 feat04
## 0.0064393331 0.0053301239 0.0041920936 0.0002712407
## feat10 dropoff_longitude feat08 feat07
## -0.0008267462 -0.0028571376 -0.0032224575 -0.0033964407
## feat03 id pickup_longitude pickup_latitude
## -0.0051321869 -0.0108852506 -0.0823749095 -0.1023880823
## dropoff_latitude
## -0.1073983375
cor_data <- data.frame(Feature = names(cor_sorted), Correlation = cor_sorted)
ggplot(cor_data, aes(x = reorder(Feature, Correlation), y = Correlation)) +
geom_bar(stat = "identity", fill = "steelblue") +
coord_flip() +
labs(title = "Correlation of Features with fare_amount",
x = "Features",
y = "Correlation") +
theme_minimal()
columns_to_drop <- c("feat07", "feat01", "feat04", "feat03", "feat08", "feat05", "id", "feat10")
taxi_data <- taxi_data[, !(names(taxi_data) %in% columns_to_drop)]
str(taxi_data)
## 'data.frame': 64830 obs. of 11 variables:
## $ dropoff_latitude : num 0.146 1.089 0.278 1.681 -0.432 ...
## $ dropoff_longitude: num 0.3546 -0.2841 0.0269 0.7589 0.2378 ...
## $ fare_amount : num 16.6 23.2 18.2 23.2 18.2 ...
## $ feat02 : num -1.3029 -0.0161 1.3077 0.3785 0.3388 ...
## $ feat06 : num -0.40148 -0.59295 0.2994 -0.00979 0.78623 ...
## $ feat09 : num 0.0456 0.5905 -0.222 -0.6389 -1.0274 ...
## $ key : chr "2014-07-22 12:28:44.0000002" "2015-02-14 14:14:50.0000005" "2015-04-30 07:19:42.0000001" "2014-03-23 02:10:00.000000176" ...
## $ passenger_count : int 1 1 1 1 2 1 1 2 1 1 ...
## $ pickup_datetime : POSIXct, format: "2014-07-22 12:28:44" "2015-02-14 14:14:50" ...
## $ pickup_latitude : num 0.60329 -0.08839 -0.09787 -0.00952 0.28573 ...
## $ pickup_longitude : num 0.861 -0.526 -0.548 -0.213 0.303 ...
#Data Splitting
set.seed(123456789)
# Splitting the data into training (70%) and testing (30%)
training_indices <- createDataPartition(taxi_data$fare_amount, p = 0.7, list = FALSE)
taxi_train <- taxi_data[training_indices, ]
taxi_test <- taxi_data[-training_indices, ]
cat("Training set size:", nrow(taxi_train), "\n")
## Training set size: 45383
cat("Test set size:", nrow(taxi_test), "\n")
## Test set size: 19447
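Because createDataPartition stratifies on the outcome, the fare distribution should be similar in both subsets; a quick comparison (sketch) confirms the split is balanced:
# Sketch: compare the fare_amount distribution across the train/test split
summary(taxi_train$fare_amount)
summary(taxi_test$fare_amount)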
##Saving key column separately
test_keys <- data.frame(key = taxi_test$key) # Ensure it's saved as a data frame with the correct column name
write.csv(test_keys, "test_keys.csv", row.names = FALSE)
cat("Test keys saved successfully as 'test_keys.csv'\n")
## Test keys saved successfully as 'test_keys.csv'
str(test_keys)
## 'data.frame': 19447 obs. of 1 variable:
## $ key: chr "2015-03-15 21:33:50.0000004" "2015-01-19 16:40:10.0000002" "2015-01-10 11:18:42.0000004" "2014-06-17 11:52:00.00000052" ...
###Removing the key column
taxi_train <- taxi_train[, !names(taxi_train) %in% "key"]
taxi_test <- taxi_test[, !names(taxi_test) %in% "key"]
taxi_data <- taxi_data[, !names(taxi_data) %in% "key"]
# Confirm that the 'key' column has been removed
cat("Dimensions of taxi_train after removal:", dim(taxi_train), "\n")
## Dimensions of taxi_train after removal: 45383 10
cat("Dimensions of taxi_test after removal:", dim(taxi_test), "\n")
## Dimensions of taxi_test after removal: 19447 10
#Models
##Random Forest
We first fit a baseline Random Forest model with 100 trees and mtry = 3.
rf_model <- randomForest(
fare_amount ~ .,
data = taxi_train,
ntree = 100, # Number of trees
mtry = 3, # Number of predictors sampled at each split
importance = TRUE
)
rf_train_preds <- predict(rf_model, newdata = taxi_train)
rf_test_preds <- predict(rf_model, newdata = taxi_test)
rf_train_metrics <- data.frame(
RMSE = RMSE(rf_train_preds, taxi_train$fare_amount),
MAE = MAE(rf_train_preds, taxi_train$fare_amount),
R2 = cor(rf_train_preds, taxi_train$fare_amount)^2
)
rf_test_metrics <- data.frame(
RMSE = RMSE(rf_test_preds, taxi_test$fare_amount),
MAE = MAE(rf_test_preds, taxi_test$fare_amount),
R2 = cor(rf_test_preds, taxi_test$fare_amount)^2
)
# Print Metrics
print("Performance on Training Set (Random Forest):")
## [1] "Performance on Training Set (Random Forest):"
print(rf_train_metrics)
## RMSE MAE R2
## 1 1.198399 0.8900356 0.9569847
print("Performance on Testing Set (Random Forest):")
## [1] "Performance on Testing Set (Random Forest):"
print(rf_test_metrics)
## RMSE MAE R2
## 1 2.69144 2.048448 0.6910325
# Variable Importance
print("Variable Importance (Random Forest):")
## [1] "Variable Importance (Random Forest):"
print(importance(rf_model))
## %IncMSE IncNodePurity
## dropoff_latitude 120.7932426 232421.196
## dropoff_longitude 87.1414272 145423.179
## feat02 26.4330407 77625.354
## feat06 -0.2514204 44720.056
## feat09 17.8731927 75155.414
## passenger_count 2.3641040 7289.917
## pickup_datetime 2.1670321 46130.402
## pickup_latitude 99.7058647 201995.414
## pickup_longitude 76.4515250 140420.189
varImpPlot(rf_model, main = "Random Forest Variable Importance")
rf_combined_metrics <- data.frame(
Set = c("Training", "Testing"),
RMSE = c(rf_train_metrics$RMSE, rf_test_metrics$RMSE),
MAE = c(rf_train_metrics$MAE, rf_test_metrics$MAE),
R2 = c(rf_train_metrics$R2, rf_test_metrics$R2)
)
print(rf_combined_metrics)
## Set RMSE MAE R2
## 1 Training 1.198399 0.8900356 0.9569847
## 2 Testing 2.691440 2.0484480 0.6910325
# Visualize the table using kable
kable(
rf_combined_metrics,
caption = "Performance Metrics for Random Forest",
col.names = c("Dataset", "RMSE", "MAE", "R²")
)
| Dataset | RMSE | MAE | R² |
|---|---|---|---|
| Training | 1.198399 | 0.8900356 | 0.9569847 |
| Testing | 2.691440 | 2.0484480 | 0.6910325 |
Features like passenger_count and pickup_datetime have minimal impact on the model, so dropping them may simplify the model without hurting performance. The large gap between the training and testing metrics also suggests the model is overfitting.
taxi_train <- taxi_train[, !(names(taxi_train) %in% c("passenger_count", "pickup_datetime"))]
taxi_test <- taxi_test[, !(names(taxi_test) %in% c("passenger_count", "pickup_datetime"))]
# Retrain the Random Forest model
rf_model <- randomForest(
fare_amount ~ .,
data = taxi_train,
ntree = 100,
mtry = 3,
importance = TRUE
)
The mtry parameter controls how many features are sampled at each split, so testing different values can improve the model. The loop below evaluates the Random Forest for several values of mtry and helps identify the value that minimizes the test RMSE.
#different mtry values
for (mtry_val in 2:5) {
rf_model <- randomForest(
fare_amount ~ .,
data = taxi_train,
ntree = 100,
mtry = mtry_val,
importance = TRUE
)
print(paste("mtry:", mtry_val))
print(RMSE(predict(rf_model, taxi_test), taxi_test$fare_amount))
}
## [1] "mtry: 2"
## [1] 2.669941
## [1] "mtry: 3"
## [1] 2.588312
## [1] "mtry: 4"
## [1] 2.549705
## [1] "mtry: 5"
## [1] 2.530348
We then retrained the Random Forest with more trees (300) and the best mtry from the search above (mtry = 5).
rf_model <- randomForest(
fare_amount ~ .,
data = taxi_train,
ntree = 300, # Increased trees for stability
mtry = 5, # Best mtry from the grid search above
importance = TRUE # Track variable importance
)
# Evaluate performance
rf_train_preds <- predict(rf_model, newdata = taxi_train)
rf_test_preds <- predict(rf_model, newdata = taxi_test)
# Training metrics
rf_train_metrics <- data.frame(
RMSE = RMSE(rf_train_preds, taxi_train$fare_amount),
MAE = MAE(rf_train_preds, taxi_train$fare_amount),
R2 = cor(rf_train_preds, taxi_train$fare_amount)^2
)
# Testing metrics
rf_test_metrics <- data.frame(
RMSE = RMSE(rf_test_preds, taxi_test$fare_amount),
MAE = MAE(rf_test_preds, taxi_test$fare_amount),
R2 = cor(rf_test_preds, taxi_test$fare_amount)^2
)
# Print metrics
print("Training Performance:")
## [1] "Training Performance:"
print(rf_train_metrics)
## RMSE MAE R2
## 1 1.081727 0.7899513 0.9600158
print("Testing Performance:")
## [1] "Testing Performance:"
print(rf_test_metrics)
## RMSE MAE R2
## 1 2.523936 1.882638 0.7160709
rf_combined_metrics <- data.frame(
Dataset = c("Training", "Testing"),
RMSE = c(rf_train_metrics$RMSE, rf_test_metrics$RMSE),
MAE = c(rf_train_metrics$MAE, rf_test_metrics$MAE),
R2 = c(rf_train_metrics$R2, rf_test_metrics$R2)
)
print(rf_combined_metrics)
## Dataset RMSE MAE R2
## 1 Training 1.081727 0.7899513 0.9600158
## 2 Testing 2.523936 1.8826380 0.7160709
kable(
rf_combined_metrics,
caption = "Performance Metrics for Random Forest",
col.names = c("Dataset", "RMSE", "MAE", "R²")
)
| Dataset | RMSE | MAE | R² |
|---|---|---|---|
| Training | 1.081727 | 0.7899513 | 0.9600158 |
| Testing | 2.523936 | 1.8826380 | 0.7160709 |
# Variable Importance
print("Variable Importance (Random Forest):")
## [1] "Variable Importance (Random Forest):"
print(importance(rf_model))
## %IncMSE IncNodePurity
## dropoff_latitude 351.616742 261287.57
## dropoff_longitude 209.389452 145312.63
## feat02 33.765441 69927.10
## feat06 -1.377709 38355.40
## feat09 28.036997 65191.02
## pickup_latitude 321.358392 259351.41
## pickup_longitude 189.863615 143265.73
varImpPlot(rf_model, main = "Random Forest Variable Importance")
To ensure the model generalizes well, we used 5-fold cross-validation to optimize the mtry hyperparameter. In this case the selected value was mtry = 5.
set.seed(12345)
control <- trainControl(
method = "cv",
number = 5,
verboseIter = TRUE
)
rf_grid <- expand.grid(
mtry = 2:5
)
rf_cv_model <- caret::train(
fare_amount ~ .,
data = taxi_train,
method = "rf",
trControl = control,
tuneGrid = rf_grid,
ntree = 300
)
## + Fold1: mtry=2
## - Fold1: mtry=2
## + Fold1: mtry=3
## - Fold1: mtry=3
## + Fold1: mtry=4
## - Fold1: mtry=4
## + Fold1: mtry=5
## - Fold1: mtry=5
## + Fold2: mtry=2
## - Fold2: mtry=2
## + Fold2: mtry=3
## - Fold2: mtry=3
## + Fold2: mtry=4
## - Fold2: mtry=4
## + Fold2: mtry=5
## - Fold2: mtry=5
## + Fold3: mtry=2
## - Fold3: mtry=2
## + Fold3: mtry=3
## - Fold3: mtry=3
## + Fold3: mtry=4
## - Fold3: mtry=4
## + Fold3: mtry=5
## - Fold3: mtry=5
## + Fold4: mtry=2
## - Fold4: mtry=2
## + Fold4: mtry=3
## - Fold4: mtry=3
## + Fold4: mtry=4
## - Fold4: mtry=4
## + Fold4: mtry=5
## - Fold4: mtry=5
## + Fold5: mtry=2
## - Fold5: mtry=2
## + Fold5: mtry=3
## - Fold5: mtry=3
## + Fold5: mtry=4
## - Fold5: mtry=4
## + Fold5: mtry=5
## - Fold5: mtry=5
## Aggregating results
## Selecting tuning parameters
## Fitting mtry = 5 on full training set
print(rf_cv_model)
## Random Forest
##
## 45383 samples
## 7 predictor
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 36307, 36307, 36307, 36305, 36306
## Resampling results across tuning parameters:
##
## mtry RMSE Rsquared MAE
## 2 2.732629 0.6886011 2.074324
## 3 2.640034 0.7014165 1.988368
## 4 2.599337 0.7065789 1.947884
## 5 2.579000 0.7088754 1.927013
##
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was mtry = 5.
cat("Best Number of Predictors (mtry):", rf_cv_model$bestTune$mtry, "\n")
## Best Number of Predictors (mtry): 5
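The tuning results can also be inspected visually; the sketch below uses caret's ggplot method for train objects to plot cross-validated RMSE against mtry:
# Sketch: cross-validated RMSE across the mtry grid
ggplot(rf_cv_model) +
  theme_minimal() +
  labs(title = "Cross-Validated RMSE by mtry (Random Forest)")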
Variable Importance for the cross-validated Random Forest model
cat("Variable Importance:\n")
## Variable Importance:
print(varImp(rf_cv_model))
## rf variable importance
##
## Overall
## dropoff_latitude 100.00
## pickup_latitude 99.07
## dropoff_longitude 48.14
## pickup_longitude 47.47
## feat02 14.03
## feat09 12.15
## feat06 0.00
# Plot Variable Importance
varImpPlot(rf_cv_model$finalModel, main = "Random Forest Variable Importance")
##Gradient Boosting
Step 1: Train the Initial Gradient Boosting Model
# Converting the training and test datasets to matrices
train_matrix <- as.matrix(taxi_train[, -which(names(taxi_train) == "fare_amount")])
train_target <- taxi_train$fare_amount
test_matrix <- as.matrix(taxi_test[, -which(names(taxi_test) == "fare_amount")])
test_target <- taxi_test$fare_amount
# Training an initial Gradient Boosting model
xgb_model <- xgboost(
data = train_matrix,
label = train_target,
nrounds = 100, # Initial number of boosting rounds
objective = "reg:squarederror",
eta = 0.1,
max_depth = 6,
subsample = 0.8,
colsample_bytree = 0.8,
verbose = 1
)
## [1] train-rmse:18.574426
## [2] train-rmse:16.809701
## [3] train-rmse:15.228126
## [4] train-rmse:13.817555
## [5] train-rmse:12.565182
## [6] train-rmse:11.446683
## [7] train-rmse:10.442392
## [8] train-rmse:9.543904
## [9] train-rmse:8.755858
## [10] train-rmse:8.054102
## [11] train-rmse:7.428042
## [12] train-rmse:6.879079
## [13] train-rmse:6.400431
## [14] train-rmse:5.984437
## [15] train-rmse:5.612395
## [16] train-rmse:5.270177
## [17] train-rmse:4.992094
## [18] train-rmse:4.729667
## [19] train-rmse:4.527056
## [20] train-rmse:4.349804
## [21] train-rmse:4.187637
## [22] train-rmse:4.055636
## [23] train-rmse:3.940262
## [24] train-rmse:3.842050
## [25] train-rmse:3.730975
## [26] train-rmse:3.665071
## [27] train-rmse:3.591716
## [28] train-rmse:3.475803
## [29] train-rmse:3.408748
## [30] train-rmse:3.354571
## [31] train-rmse:3.302040
## [32] train-rmse:3.239701
## [33] train-rmse:3.181370
## [34] train-rmse:3.129685
## [35] train-rmse:3.087626
## [36] train-rmse:3.075464
## [37] train-rmse:3.035011
## [38] train-rmse:3.015280
## [39] train-rmse:3.009052
## [40] train-rmse:2.969301
## [41] train-rmse:2.962032
## [42] train-rmse:2.933333
## [43] train-rmse:2.904233
## [44] train-rmse:2.899790
## [45] train-rmse:2.871734
## [46] train-rmse:2.862267
## [47] train-rmse:2.833531
## [48] train-rmse:2.827475
## [49] train-rmse:2.824945
## [50] train-rmse:2.809423
## [51] train-rmse:2.795651
## [52] train-rmse:2.776240
## [53] train-rmse:2.761669
## [54] train-rmse:2.748485
## [55] train-rmse:2.721751
## [56] train-rmse:2.706138
## [57] train-rmse:2.699262
## [58] train-rmse:2.686478
## [59] train-rmse:2.678362
## [60] train-rmse:2.667405
## [61] train-rmse:2.664332
## [62] train-rmse:2.655928
## [63] train-rmse:2.654785
## [64] train-rmse:2.652820
## [65] train-rmse:2.647158
## [66] train-rmse:2.645696
## [67] train-rmse:2.620321
## [68] train-rmse:2.612700
## [69] train-rmse:2.610180
## [70] train-rmse:2.608885
## [71] train-rmse:2.605964
## [72] train-rmse:2.603919
## [73] train-rmse:2.600713
## [74] train-rmse:2.599409
## [75] train-rmse:2.581341
## [76] train-rmse:2.580165
## [77] train-rmse:2.554154
## [78] train-rmse:2.550883
## [79] train-rmse:2.549542
## [80] train-rmse:2.547383
## [81] train-rmse:2.545644
## [82] train-rmse:2.544823
## [83] train-rmse:2.538409
## [84] train-rmse:2.536610
## [85] train-rmse:2.536070
## [86] train-rmse:2.534720
## [87] train-rmse:2.532114
## [88] train-rmse:2.531018
## [89] train-rmse:2.529542
## [90] train-rmse:2.526840
## [91] train-rmse:2.519591
## [92] train-rmse:2.518211
## [93] train-rmse:2.513946
## [94] train-rmse:2.512380
## [95] train-rmse:2.508362
## [96] train-rmse:2.482425
## [97] train-rmse:2.481286
## [98] train-rmse:2.479972
## [99] train-rmse:2.478173
## [100] train-rmse:2.477129
# Making predictions
train_preds <- predict(xgb_model, newdata = train_matrix)
test_preds <- predict(xgb_model, newdata = test_matrix)
# Evaluating model performance
library(Metrics)
# RMSE and MAE
train_rmse <- rmse(train_target, train_preds)
test_rmse <- rmse(test_target, test_preds)
train_mae <- mae(train_target, train_preds)
test_mae <- mae(test_target, test_preds)
# R-squared
train_r2 <- cor(train_target, train_preds)^2
test_r2 <- cor(test_target, test_preds)^2
# Displaying the results
cat("Training Performance:\n")
## Training Performance:
cat("RMSE:", train_rmse, "\n")
## RMSE: 2.477129
cat("MAE:", train_mae, "\n")
## MAE: 1.862077
cat("R-squared:", train_r2, "\n")
## R-squared: 0.7299044
cat("\nTesting Performance:\n")
##
## Testing Performance:
cat("RMSE:", test_rmse, "\n")
## RMSE: 2.617148
cat("MAE:", test_mae, "\n")
## MAE: 1.971123
cat("R-squared:", test_r2, "\n")
## R-squared: 0.6917485
# Combining metrics into a single table
performance_table <- data.frame(
Metric = c("RMSE", "MAE", "R-squared"),
Training = c(train_rmse, train_mae, train_r2),
Testing = c(test_rmse, test_mae, test_r2)
)
# Printing the combined table
print("Model Performance:")
## [1] "Model Performance:"
print(performance_table)
## Metric Training Testing
## 1 RMSE 2.4771293 2.6171484
## 2 MAE 1.8620774 1.9711227
## 3 R-squared 0.7299044 0.6917485
kable(performance_table, caption = "Training and Testing Performance Metrics")
| Metric | Training | Testing |
|---|---|---|
| RMSE | 2.4771293 | 2.6171484 |
| MAE | 1.8620774 | 1.9711227 |
| R-squared | 0.7299044 | 0.6917485 |
# Feature Importance
importance_matrix <- xgb.importance(model = xgb_model)
print("Feature Importance:")
## [1] "Feature Importance:"
print(importance_matrix)
## Feature Gain Cover Frequency
## <char> <num> <num> <num>
## 1: pickup_latitude 0.283760458 0.19205992 0.19427516
## 2: dropoff_latitude 0.271884783 0.21273324 0.21089566
## 3: dropoff_longitude 0.184469740 0.17547106 0.19335180
## 4: pickup_longitude 0.167276434 0.14228400 0.16177285
## 5: feat02 0.049021665 0.12732690 0.09529086
## 6: feat09 0.039882048 0.11199921 0.08254848
## 7: feat06 0.003704872 0.03812567 0.06186519
xgb.plot.importance(
importance_matrix,
rel_to_first = TRUE,
top_n = 10,
main = "Feature Importance (XGBoost)"
)
We used cross-validation to find the optimal nrounds. This step helps determine the ideal number of boosting rounds and guards against overfitting.
# Perform cross-validation
cv_results <- xgb.cv(
data = train_matrix,
label = train_target,
nfold = 5,
nrounds = 300,
objective = "reg:squarederror",
eta = 0.1,
max_depth = 6,
subsample = 0.8,
colsample_bytree = 0.8,
verbose = 1,
early_stopping_rounds = 10 # Stop early if no improvement
)
## [1] train-rmse:18.567339+0.007496 test-rmse:18.570053+0.029179
## Multiple eval metrics are present. Will use test_rmse for early stopping.
## Will train until test_rmse hasn't improved in 10 rounds.
##
## [2] train-rmse:16.811445+0.009895 test-rmse:16.816567+0.026865
## [3] train-rmse:15.238163+0.012633 test-rmse:15.245205+0.024645
## [4] train-rmse:13.831930+0.018499 test-rmse:13.841370+0.026352
## [5] train-rmse:12.571182+0.015409 test-rmse:12.583494+0.023691
## [6] train-rmse:11.441930+0.016591 test-rmse:11.456992+0.030346
## [7] train-rmse:10.435702+0.017477 test-rmse:10.452708+0.021339
## [8] train-rmse:9.541816+0.020163 test-rmse:9.561214+0.020089
## [9] train-rmse:8.750467+0.021237 test-rmse:8.772148+0.021722
## [10] train-rmse:8.049487+0.025981 test-rmse:8.074985+0.025434
## [11] train-rmse:7.428701+0.018742 test-rmse:7.459265+0.021088
## [12] train-rmse:6.886161+0.021955 test-rmse:6.919866+0.023184
## [13] train-rmse:6.407931+0.018947 test-rmse:6.444886+0.022687
## [14] train-rmse:5.986369+0.025269 test-rmse:6.026761+0.038886
## [15] train-rmse:5.614238+0.041295 test-rmse:5.657907+0.056524
## [16] train-rmse:5.287115+0.048568 test-rmse:5.335982+0.065838
## [17] train-rmse:5.006454+0.051749 test-rmse:5.059849+0.068943
## [18] train-rmse:4.756266+0.059175 test-rmse:4.812189+0.077370
## [19] train-rmse:4.529476+0.060246 test-rmse:4.589898+0.074925
## [20] train-rmse:4.340401+0.056393 test-rmse:4.403774+0.070268
## [21] train-rmse:4.193157+0.054169 test-rmse:4.259346+0.065356
## [22] train-rmse:4.058629+0.058839 test-rmse:4.128919+0.072115
## [23] train-rmse:3.943047+0.061080 test-rmse:4.016358+0.070954
## [24] train-rmse:3.836419+0.071643 test-rmse:3.914467+0.078612
## [25] train-rmse:3.729900+0.067302 test-rmse:3.810635+0.071737
## [26] train-rmse:3.644203+0.062048 test-rmse:3.727045+0.064550
## [27] train-rmse:3.572742+0.077149 test-rmse:3.658653+0.079413
## [28] train-rmse:3.507474+0.073363 test-rmse:3.595457+0.074415
## [29] train-rmse:3.438667+0.068138 test-rmse:3.529862+0.066384
## [30] train-rmse:3.368544+0.064450 test-rmse:3.462030+0.062989
## [31] train-rmse:3.322592+0.047356 test-rmse:3.417903+0.048470
## [32] train-rmse:3.268961+0.045969 test-rmse:3.365941+0.043086
## [33] train-rmse:3.226572+0.047047 test-rmse:3.324909+0.045502
## [34] train-rmse:3.191433+0.050063 test-rmse:3.291983+0.048549
## [35] train-rmse:3.155592+0.042159 test-rmse:3.256872+0.038018
## [36] train-rmse:3.121371+0.035003 test-rmse:3.224500+0.034114
## [37] train-rmse:3.081695+0.031362 test-rmse:3.187426+0.024639
## [38] train-rmse:3.048816+0.030343 test-rmse:3.156891+0.030966
## [39] train-rmse:3.020182+0.029322 test-rmse:3.130501+0.024005
## [40] train-rmse:2.995083+0.033634 test-rmse:3.106533+0.028038
## [41] train-rmse:2.973613+0.029584 test-rmse:3.086079+0.030840
## [42] train-rmse:2.951017+0.043089 test-rmse:3.065450+0.043013
## [43] train-rmse:2.931011+0.043599 test-rmse:3.046792+0.047999
## [44] train-rmse:2.912720+0.040683 test-rmse:3.030165+0.040684
## [45] train-rmse:2.892319+0.044488 test-rmse:3.011934+0.037679
## [46] train-rmse:2.871778+0.041584 test-rmse:2.992795+0.034192
## [47] train-rmse:2.846210+0.030767 test-rmse:2.968614+0.029375
## [48] train-rmse:2.829184+0.034040 test-rmse:2.953463+0.030524
## [49] train-rmse:2.810296+0.035913 test-rmse:2.935940+0.030692
## [50] train-rmse:2.793430+0.029605 test-rmse:2.920313+0.033010
## [51] train-rmse:2.772944+0.028133 test-rmse:2.901820+0.037364
## [52] train-rmse:2.754525+0.029714 test-rmse:2.885434+0.032492
## [53] train-rmse:2.741401+0.023085 test-rmse:2.873804+0.037554
## [54] train-rmse:2.723369+0.016548 test-rmse:2.858818+0.030428
## [55] train-rmse:2.709189+0.021185 test-rmse:2.846779+0.031919
## [56] train-rmse:2.705111+0.019770 test-rmse:2.844142+0.032552
## [57] train-rmse:2.695483+0.015985 test-rmse:2.835988+0.030574
## [58] train-rmse:2.683056+0.014555 test-rmse:2.825376+0.024487
## [59] train-rmse:2.675564+0.018017 test-rmse:2.819474+0.024485
## [60] train-rmse:2.664403+0.012222 test-rmse:2.810200+0.027104
## [61] train-rmse:2.658429+0.013384 test-rmse:2.805981+0.027495
## [62] train-rmse:2.649419+0.013807 test-rmse:2.798965+0.030831
## [63] train-rmse:2.636716+0.012669 test-rmse:2.787796+0.032012
## [64] train-rmse:2.633043+0.013372 test-rmse:2.785330+0.033418
## [65] train-rmse:2.626432+0.014974 test-rmse:2.780280+0.032856
## [66] train-rmse:2.620637+0.014644 test-rmse:2.776326+0.034551
## [67] train-rmse:2.617936+0.015115 test-rmse:2.775027+0.035486
## [68] train-rmse:2.613032+0.012747 test-rmse:2.771760+0.031900
## [69] train-rmse:2.607928+0.013006 test-rmse:2.767799+0.033096
## [70] train-rmse:2.599083+0.014423 test-rmse:2.760718+0.035761
## [71] train-rmse:2.591321+0.017744 test-rmse:2.754422+0.040641
## [72] train-rmse:2.586726+0.017745 test-rmse:2.751465+0.040528
## [73] train-rmse:2.575251+0.020092 test-rmse:2.741746+0.041170
## [74] train-rmse:2.565546+0.021274 test-rmse:2.734104+0.039767
## [75] train-rmse:2.562801+0.020836 test-rmse:2.732556+0.039411
## [76] train-rmse:2.558215+0.019279 test-rmse:2.729160+0.040585
## [77] train-rmse:2.555307+0.017617 test-rmse:2.727786+0.040151
## [78] train-rmse:2.547446+0.018843 test-rmse:2.721690+0.039822
## [79] train-rmse:2.542203+0.018687 test-rmse:2.717911+0.041310
## [80] train-rmse:2.536716+0.016586 test-rmse:2.714052+0.044077
## [81] train-rmse:2.531196+0.020612 test-rmse:2.709806+0.046073
## [82] train-rmse:2.524123+0.022088 test-rmse:2.704766+0.043840
## [83] train-rmse:2.521600+0.022569 test-rmse:2.703831+0.044619
## [84] train-rmse:2.514604+0.027766 test-rmse:2.698596+0.049611
## [85] train-rmse:2.508624+0.027454 test-rmse:2.694660+0.049370
## [86] train-rmse:2.507333+0.027469 test-rmse:2.694838+0.049446
## [87] train-rmse:2.505481+0.027576 test-rmse:2.694359+0.049177
## [88] train-rmse:2.501021+0.026718 test-rmse:2.691769+0.047436
## [89] train-rmse:2.499351+0.026872 test-rmse:2.691186+0.047479
## [90] train-rmse:2.490671+0.026433 test-rmse:2.684548+0.044183
## [91] train-rmse:2.484036+0.024990 test-rmse:2.678901+0.043079
## [92] train-rmse:2.481829+0.025037 test-rmse:2.678105+0.043562
## [93] train-rmse:2.479970+0.025439 test-rmse:2.677747+0.044136
## [94] train-rmse:2.476755+0.024184 test-rmse:2.675718+0.045595
## [95] train-rmse:2.472286+0.028045 test-rmse:2.672627+0.046078
## [96] train-rmse:2.464533+0.026483 test-rmse:2.665913+0.047865
## [97] train-rmse:2.461341+0.024504 test-rmse:2.664175+0.045527
## [98] train-rmse:2.458848+0.024286 test-rmse:2.662822+0.045186
## [99] train-rmse:2.457393+0.024305 test-rmse:2.662711+0.045159
## [100] train-rmse:2.454487+0.023088 test-rmse:2.661115+0.044669
## [101] train-rmse:2.451294+0.022532 test-rmse:2.659232+0.043199
## [102] train-rmse:2.449322+0.022657 test-rmse:2.658724+0.042980
## [103] train-rmse:2.448083+0.022439 test-rmse:2.658586+0.042569
## [104] train-rmse:2.444482+0.022680 test-rmse:2.656299+0.042356
## [105] train-rmse:2.441512+0.020573 test-rmse:2.654599+0.039539
## [106] train-rmse:2.440054+0.019918 test-rmse:2.654292+0.039092
## [107] train-rmse:2.436173+0.017602 test-rmse:2.651916+0.035980
## [108] train-rmse:2.432175+0.020655 test-rmse:2.649120+0.039148
## [109] train-rmse:2.424157+0.024546 test-rmse:2.642732+0.042360
## [110] train-rmse:2.419737+0.022238 test-rmse:2.639316+0.038158
## [111] train-rmse:2.418173+0.022207 test-rmse:2.638837+0.038406
## [112] train-rmse:2.414366+0.022351 test-rmse:2.636254+0.035419
## [113] train-rmse:2.410194+0.021799 test-rmse:2.633446+0.037162
## [114] train-rmse:2.408192+0.022577 test-rmse:2.632737+0.037043
## [115] train-rmse:2.405642+0.022381 test-rmse:2.631923+0.037141
## [116] train-rmse:2.402856+0.023606 test-rmse:2.630712+0.037455
## [117] train-rmse:2.400919+0.024302 test-rmse:2.629939+0.038103
## [118] train-rmse:2.398378+0.024288 test-rmse:2.628274+0.038007
## [119] train-rmse:2.397007+0.025045 test-rmse:2.627883+0.038425
## [120] train-rmse:2.393828+0.027224 test-rmse:2.626083+0.039072
## [121] train-rmse:2.392681+0.027105 test-rmse:2.625975+0.039177
## [122] train-rmse:2.385558+0.024958 test-rmse:2.620598+0.036404
## [123] train-rmse:2.384162+0.025220 test-rmse:2.620156+0.036338
## [124] train-rmse:2.381026+0.027001 test-rmse:2.618883+0.036257
## [125] train-rmse:2.377221+0.025960 test-rmse:2.616468+0.033865
## [126] train-rmse:2.372789+0.023669 test-rmse:2.613251+0.034219
## [127] train-rmse:2.369247+0.022745 test-rmse:2.610921+0.036180
## [128] train-rmse:2.367565+0.023461 test-rmse:2.610150+0.035127
## [129] train-rmse:2.366401+0.023309 test-rmse:2.610174+0.035222
## [130] train-rmse:2.363117+0.022771 test-rmse:2.608221+0.036408
## [131] train-rmse:2.359665+0.021250 test-rmse:2.605939+0.034543
## [132] train-rmse:2.357519+0.020625 test-rmse:2.604932+0.034951
## [133] train-rmse:2.355797+0.021191 test-rmse:2.604387+0.035143
## [134] train-rmse:2.349629+0.021012 test-rmse:2.599806+0.037481
## [135] train-rmse:2.347115+0.021285 test-rmse:2.598161+0.037135
## [136] train-rmse:2.345417+0.021550 test-rmse:2.597775+0.037451
## [137] train-rmse:2.343246+0.021258 test-rmse:2.596775+0.038678
## [138] train-rmse:2.339994+0.019895 test-rmse:2.594986+0.037890
## [139] train-rmse:2.335515+0.014871 test-rmse:2.591667+0.034948
## [140] train-rmse:2.332542+0.014494 test-rmse:2.590166+0.036664
## [141] train-rmse:2.331251+0.014679 test-rmse:2.589812+0.036814
## [142] train-rmse:2.328323+0.015471 test-rmse:2.588277+0.036037
## [143] train-rmse:2.326609+0.015787 test-rmse:2.587680+0.035732
## [144] train-rmse:2.325477+0.015754 test-rmse:2.587528+0.035667
## [145] train-rmse:2.324056+0.015593 test-rmse:2.587361+0.035421
## [146] train-rmse:2.321199+0.016641 test-rmse:2.585513+0.033396
## [147] train-rmse:2.318323+0.017134 test-rmse:2.584279+0.034336
## [148] train-rmse:2.313345+0.019574 test-rmse:2.580571+0.031037
## [149] train-rmse:2.311598+0.019849 test-rmse:2.580476+0.030896
## [150] train-rmse:2.310271+0.019855 test-rmse:2.579963+0.031417
## [151] train-rmse:2.307320+0.017250 test-rmse:2.578207+0.029601
## [152] train-rmse:2.306209+0.017185 test-rmse:2.578219+0.029708
## [153] train-rmse:2.304424+0.016628 test-rmse:2.577346+0.029970
## [154] train-rmse:2.303147+0.015996 test-rmse:2.576888+0.029558
## [155] train-rmse:2.302032+0.015806 test-rmse:2.576833+0.029533
## [156] train-rmse:2.298599+0.015485 test-rmse:2.574603+0.031748
## [157] train-rmse:2.294838+0.015652 test-rmse:2.572276+0.030608
## [158] train-rmse:2.289950+0.015607 test-rmse:2.569096+0.027725
## [159] train-rmse:2.288737+0.015336 test-rmse:2.569124+0.028032
## [160] train-rmse:2.287314+0.014980 test-rmse:2.568885+0.028288
## [161] train-rmse:2.283433+0.018421 test-rmse:2.566294+0.027507
## [162] train-rmse:2.281656+0.018243 test-rmse:2.565232+0.028002
## [163] train-rmse:2.279766+0.018764 test-rmse:2.564588+0.027434
## [164] train-rmse:2.276258+0.020383 test-rmse:2.562532+0.025109
## [165] train-rmse:2.274434+0.019544 test-rmse:2.562052+0.025859
## [166] train-rmse:2.270844+0.019772 test-rmse:2.559150+0.025334
## [167] train-rmse:2.267428+0.018768 test-rmse:2.557072+0.027673
## [168] train-rmse:2.266054+0.018523 test-rmse:2.556695+0.027820
## [169] train-rmse:2.263739+0.017139 test-rmse:2.555695+0.029343
## [170] train-rmse:2.260955+0.018156 test-rmse:2.553993+0.028139
## [171] train-rmse:2.258762+0.019837 test-rmse:2.552931+0.027289
## [172] train-rmse:2.257029+0.019620 test-rmse:2.552476+0.027325
## [173] train-rmse:2.255922+0.019563 test-rmse:2.552363+0.027441
## [174] train-rmse:2.254527+0.019672 test-rmse:2.552127+0.027757
## [175] train-rmse:2.250441+0.020231 test-rmse:2.549579+0.027834
## [176] train-rmse:2.247775+0.021454 test-rmse:2.548161+0.027227
## [177] train-rmse:2.244899+0.020343 test-rmse:2.546482+0.027296
## [178] train-rmse:2.243535+0.020160 test-rmse:2.546150+0.026948
## [179] train-rmse:2.241248+0.021130 test-rmse:2.545111+0.026369
## [180] train-rmse:2.238492+0.019299 test-rmse:2.543439+0.025807
## [181] train-rmse:2.236971+0.020161 test-rmse:2.542766+0.025148
## [182] train-rmse:2.232990+0.022813 test-rmse:2.540242+0.024314
## [183] train-rmse:2.231693+0.022603 test-rmse:2.540204+0.024428
## [184] train-rmse:2.228226+0.025287 test-rmse:2.538292+0.022804
## [185] train-rmse:2.226594+0.026049 test-rmse:2.537815+0.022713
## [186] train-rmse:2.223345+0.024784 test-rmse:2.535722+0.024689
## [187] train-rmse:2.219438+0.022833 test-rmse:2.533024+0.028555
## [188] train-rmse:2.215572+0.025546 test-rmse:2.530395+0.026931
## [189] train-rmse:2.214470+0.025465 test-rmse:2.530392+0.026954
## [190] train-rmse:2.213406+0.025339 test-rmse:2.530354+0.027008
## [191] train-rmse:2.210896+0.024369 test-rmse:2.529397+0.026147
## [192] train-rmse:2.209048+0.023862 test-rmse:2.528530+0.026856
## [193] train-rmse:2.206189+0.025556 test-rmse:2.527089+0.026426
## [194] train-rmse:2.204850+0.025794 test-rmse:2.526925+0.026215
## [195] train-rmse:2.203636+0.025618 test-rmse:2.526832+0.026306
## [196] train-rmse:2.201531+0.024759 test-rmse:2.526368+0.027157
## [197] train-rmse:2.200234+0.024552 test-rmse:2.525960+0.026938
## [198] train-rmse:2.196829+0.022292 test-rmse:2.524090+0.025091
## [199] train-rmse:2.195635+0.022312 test-rmse:2.523972+0.025446
## [200] train-rmse:2.192146+0.022188 test-rmse:2.521761+0.026351
## [201] train-rmse:2.189766+0.022502 test-rmse:2.520353+0.026446
## [202] train-rmse:2.187821+0.023054 test-rmse:2.519721+0.025608
## [203] train-rmse:2.186843+0.023082 test-rmse:2.519844+0.025534
## [204] train-rmse:2.185348+0.023628 test-rmse:2.519366+0.025584
## [205] train-rmse:2.183302+0.023650 test-rmse:2.518653+0.025887
## [206] train-rmse:2.182528+0.023631 test-rmse:2.518648+0.025858
## [207] train-rmse:2.181245+0.024118 test-rmse:2.518199+0.025873
## [208] train-rmse:2.179558+0.024047 test-rmse:2.517896+0.025992
## [209] train-rmse:2.178527+0.024121 test-rmse:2.517630+0.026220
## [210] train-rmse:2.177598+0.024097 test-rmse:2.517598+0.026190
## [211] train-rmse:2.174735+0.022549 test-rmse:2.516211+0.024694
## [212] train-rmse:2.172970+0.022084 test-rmse:2.514942+0.024373
## [213] train-rmse:2.170619+0.022241 test-rmse:2.513952+0.024516
## [214] train-rmse:2.168728+0.022523 test-rmse:2.513221+0.024850
## [215] train-rmse:2.166192+0.022682 test-rmse:2.512219+0.025488
## [216] train-rmse:2.164387+0.022654 test-rmse:2.511586+0.026303
## [217] train-rmse:2.162219+0.023794 test-rmse:2.510681+0.025701
## [218] train-rmse:2.158723+0.024972 test-rmse:2.508538+0.025473
## [219] train-rmse:2.156934+0.025057 test-rmse:2.507972+0.026181
## [220] train-rmse:2.154885+0.026204 test-rmse:2.506937+0.025590
## [221] train-rmse:2.153652+0.026494 test-rmse:2.506754+0.025586
## [222] train-rmse:2.151500+0.027526 test-rmse:2.505734+0.025661
## [223] train-rmse:2.150145+0.027393 test-rmse:2.505529+0.025911
## [224] train-rmse:2.147815+0.028929 test-rmse:2.504612+0.025447
## [225] train-rmse:2.145587+0.028607 test-rmse:2.504261+0.025720
## [226] train-rmse:2.144265+0.028820 test-rmse:2.503970+0.025509
## [227] train-rmse:2.142978+0.028346 test-rmse:2.503673+0.024976
## [228] train-rmse:2.139614+0.025450 test-rmse:2.501276+0.026900
## [229] train-rmse:2.138520+0.025619 test-rmse:2.501075+0.026837
## [230] train-rmse:2.136383+0.024352 test-rmse:2.500225+0.025955
## [231] train-rmse:2.135299+0.024464 test-rmse:2.500105+0.025816
## [232] train-rmse:2.133459+0.025258 test-rmse:2.499828+0.025631
## [233] train-rmse:2.132118+0.024587 test-rmse:2.499327+0.025476
## [234] train-rmse:2.131334+0.024543 test-rmse:2.499327+0.025415
## [235] train-rmse:2.129628+0.023871 test-rmse:2.498891+0.025831
## [236] train-rmse:2.127479+0.023337 test-rmse:2.498317+0.026373
## [237] train-rmse:2.125706+0.023333 test-rmse:2.497652+0.025943
## [238] train-rmse:2.124767+0.023398 test-rmse:2.497670+0.025865
## [239] train-rmse:2.122247+0.021611 test-rmse:2.496370+0.026768
## [240] train-rmse:2.119800+0.022820 test-rmse:2.495420+0.025708
## [241] train-rmse:2.118744+0.022548 test-rmse:2.495103+0.025454
## [242] train-rmse:2.117637+0.022459 test-rmse:2.494885+0.025278
## [243] train-rmse:2.116652+0.022276 test-rmse:2.494853+0.025264
## [244] train-rmse:2.115621+0.022392 test-rmse:2.494585+0.025357
## [245] train-rmse:2.114540+0.022585 test-rmse:2.494426+0.025303
## [246] train-rmse:2.111325+0.020742 test-rmse:2.492833+0.026119
## [247] train-rmse:2.109454+0.021018 test-rmse:2.492596+0.026224
## [248] train-rmse:2.107871+0.021241 test-rmse:2.492186+0.025979
## [249] train-rmse:2.106523+0.021805 test-rmse:2.492039+0.025720
## [250] train-rmse:2.105493+0.022070 test-rmse:2.492106+0.025555
## [251] train-rmse:2.104020+0.022044 test-rmse:2.491712+0.025712
## [252] train-rmse:2.101961+0.021698 test-rmse:2.491047+0.025668
## [253] train-rmse:2.100192+0.021308 test-rmse:2.490507+0.026185
## [254] train-rmse:2.098186+0.020407 test-rmse:2.489775+0.026182
## [255] train-rmse:2.097451+0.020649 test-rmse:2.489694+0.026069
## [256] train-rmse:2.095021+0.021017 test-rmse:2.488316+0.026403
## [257] train-rmse:2.093645+0.020562 test-rmse:2.488180+0.026610
## [258] train-rmse:2.091990+0.020460 test-rmse:2.487692+0.026323
## [259] train-rmse:2.090785+0.020908 test-rmse:2.487637+0.026269
## [260] train-rmse:2.090074+0.020827 test-rmse:2.487496+0.026086
## [261] train-rmse:2.088275+0.020947 test-rmse:2.486786+0.026530
## [262] train-rmse:2.087027+0.021211 test-rmse:2.486777+0.026580
## [263] train-rmse:2.085810+0.021054 test-rmse:2.486765+0.026870
## [264] train-rmse:2.084111+0.021273 test-rmse:2.486390+0.027289
## [265] train-rmse:2.081977+0.021384 test-rmse:2.485589+0.027919
## [266] train-rmse:2.080817+0.021513 test-rmse:2.485273+0.028121
## [267] train-rmse:2.079615+0.021220 test-rmse:2.485103+0.028080
## [268] train-rmse:2.078519+0.021447 test-rmse:2.484896+0.027999
## [269] train-rmse:2.077650+0.021519 test-rmse:2.484694+0.027977
## [270] train-rmse:2.076541+0.021535 test-rmse:2.484355+0.028120
## [271] train-rmse:2.075681+0.021444 test-rmse:2.484400+0.028127
## [272] train-rmse:2.074584+0.021234 test-rmse:2.484488+0.028199
## [273] train-rmse:2.073712+0.021288 test-rmse:2.484399+0.028175
## [274] train-rmse:2.072761+0.021272 test-rmse:2.484426+0.028235
## [275] train-rmse:2.071673+0.021505 test-rmse:2.484460+0.028251
## [276] train-rmse:2.070815+0.021553 test-rmse:2.484480+0.028165
## [277] train-rmse:2.069467+0.021281 test-rmse:2.484110+0.028569
## [278] train-rmse:2.068371+0.021003 test-rmse:2.483871+0.028523
## [279] train-rmse:2.066694+0.020690 test-rmse:2.483407+0.028292
## [280] train-rmse:2.065405+0.021364 test-rmse:2.483082+0.027929
## [281] train-rmse:2.064625+0.021361 test-rmse:2.483082+0.027941
## [282] train-rmse:2.062464+0.022404 test-rmse:2.482477+0.027423
## [283] train-rmse:2.061343+0.022403 test-rmse:2.482245+0.027450
## [284] train-rmse:2.060363+0.022574 test-rmse:2.482096+0.027540
## [285] train-rmse:2.059054+0.022204 test-rmse:2.481595+0.027344
## [286] train-rmse:2.057954+0.022477 test-rmse:2.481456+0.027223
## [287] train-rmse:2.056401+0.021660 test-rmse:2.480728+0.026844
## [288] train-rmse:2.055567+0.021908 test-rmse:2.480809+0.026920
## [289] train-rmse:2.053356+0.022051 test-rmse:2.479985+0.027500
## [290] train-rmse:2.051926+0.021895 test-rmse:2.479620+0.027713
## [291] train-rmse:2.048455+0.020269 test-rmse:2.477550+0.029062
## [292] train-rmse:2.046787+0.019386 test-rmse:2.476842+0.028684
## [293] train-rmse:2.046010+0.019529 test-rmse:2.476824+0.028661
## [294] train-rmse:2.044658+0.019738 test-rmse:2.476672+0.028646
## [295] train-rmse:2.043353+0.019683 test-rmse:2.476740+0.028642
## [296] train-rmse:2.042294+0.019897 test-rmse:2.476613+0.028591
## [297] train-rmse:2.041190+0.020111 test-rmse:2.476559+0.028507
## [298] train-rmse:2.040388+0.020262 test-rmse:2.476538+0.028595
## [299] train-rmse:2.038494+0.019192 test-rmse:2.475837+0.028732
## [300] train-rmse:2.036850+0.018846 test-rmse:2.475246+0.029273
# Save the cross-validation results to a CSV
write.csv(cv_results$evaluation_log, "xgb_cv_results.csv", row.names = FALSE)
cv_results_saved <- read.csv("xgb_cv_results.csv")
print(head(cv_results_saved))
## iter train_rmse_mean train_rmse_std test_rmse_mean test_rmse_std
## 1 1 18.56734 0.007495757 18.57005 0.02917894
## 2 2 16.81145 0.009894786 16.81657 0.02686496
## 3 3 15.23816 0.012632769 15.24520 0.02464496
## 4 4 13.83193 0.018498680 13.84137 0.02635198
## 5 5 12.57118 0.015408893 12.58349 0.02369056
## 6 6 11.44193 0.016591170 11.45699 0.03034644
The cross-validated test RMSE starts high, at about 18.57, and decreases consistently with each boosting round, falling to roughly 2.66 by round 100 and about 2.48 by round 300. The training and test curves track each other closely, which suggests the model is learning effectively without severe overfitting during training.
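To see this trajectory at a glance, the evaluation log from xgb.cv can be plotted (a sketch; the column names follow the evaluation_log shown above):
# Sketch: train vs cross-validated test RMSE by boosting round
eval_log <- as.data.frame(cv_results$evaluation_log)
ggplot(eval_log, aes(x = iter)) +
  geom_line(aes(y = train_rmse_mean, colour = "Train")) +
  geom_line(aes(y = test_rmse_mean, colour = "Test (CV)")) +
  theme_minimal() +
  labs(title = "XGBoost Cross-Validation RMSE", x = "Boosting round", y = "RMSE", colour = "Set")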
Identifying the Optimal Number of Rounds
# Extract the optimal number of rounds
optimal_nrounds <- cv_results$best_iteration
cat("Optimal number of boosting rounds:", optimal_nrounds, "\n")
## Optimal number of boosting rounds: 300
We retrained the model using the number of boosting rounds identified by cross-validation, so the final model reflects the best configuration.
final_xgb_model <- xgboost(
data = train_matrix,
label = train_target,
nrounds = 299, # Boosting rounds chosen from cross-validation (best iteration was 300)
objective = "reg:squarederror",
eta = 0.1,
max_depth = 6,
subsample = 0.8,
colsample_bytree = 0.8,
verbose = 0
)
Evaluating the model
xgb_train_preds <- predict(final_xgb_model, newdata = train_matrix)
xgb_test_preds <- predict(final_xgb_model, newdata = test_matrix)
train_metrics <- data.frame(
RMSE = RMSE(xgb_train_preds, train_target),
MAE = MAE(xgb_train_preds, train_target),
R2 = cor(xgb_train_preds, train_target)^2
)
test_metrics <- data.frame(
RMSE = RMSE(xgb_test_preds, test_target),
MAE = MAE(xgb_test_preds, test_target),
R2 = cor(xgb_test_preds, test_target)^2
)
# Print metrics
print("Training Performance:")
## [1] "Training Performance:"
print(train_metrics)
## RMSE MAE R2
## 1 2.097662 1.546659 0.8040816
print("Testing Performance:")
## [1] "Testing Performance:"
print(test_metrics)
## RMSE MAE R2
## 1 2.433837 1.787104 0.7297126
train_metrics$Set <- "Training"
test_metrics$Set <- "Testing"
combined_metrics <- rbind(train_metrics, test_metrics)
combined_metrics <- combined_metrics[, c("Set", "RMSE", "MAE", "R2")]
print("Combined Performance Metrics:")
## [1] "Combined Performance Metrics:"
print(combined_metrics)
## Set RMSE MAE R2
## 1 Training 2.097662 1.546659 0.8040816
## 2 Testing 2.433837 1.787104 0.7297126
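For consistency with the earlier model summaries, the same metrics can also be rendered as a table (a small sketch using kable, in the same style as the Random Forest tables):
kable(
combined_metrics,
caption = "Performance Metrics for Tuned XGBoost",
col.names = c("Dataset", "RMSE", "MAE", "R²")
)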
Feature Importance
importance <- xgb.importance(feature_names = colnames(train_matrix), model = final_xgb_model)
# Print feature importance
print(importance)
## Feature Gain Cover Frequency
## <char> <num> <num> <num>
## 1: dropoff_latitude 0.31618583 0.17030766 0.20824742
## 2: pickup_latitude 0.26366246 0.14566413 0.16115730
## 3: dropoff_longitude 0.16218458 0.14798619 0.16960426
## 4: pickup_longitude 0.15286939 0.12887895 0.13874293
## 5: feat02 0.05063177 0.17642723 0.11739275
## 6: feat09 0.04361945 0.15333546 0.10701696
## 7: feat06 0.01084652 0.07740038 0.09783838
# Plot feature importance
xgb.plot.importance(importance, main = "Feature Importance (XGBoost)")
##Neural Networks
We trained a feed-forward Neural Network (nnet) with 10 hidden units, a linear output, and a weight decay of 0.01.
# Train a Neural Network model
nn_model <- nnet(
fare_amount ~ .,
data = taxi_train,
size = 10,
linout = TRUE,
maxit = 1000,
decay = 0.01
)
## # weights: 91
## initial value 18812949.732032
## iter 10 value 872529.622480
## iter 20 value 623952.196043
## iter 30 value 459014.409673
## iter 40 value 444008.198190
## iter 50 value 434998.335363
## iter 60 value 421197.263762
## iter 70 value 407899.938442
## iter 80 value 391883.956686
## iter 90 value 374358.561046
## iter 100 value 344788.762200
## iter 110 value 309944.585682
## iter 120 value 292610.009442
## iter 130 value 280286.205414
## iter 140 value 276342.057249
## iter 150 value 274225.571156
## iter 160 value 273082.115944
## iter 170 value 272537.823340
## iter 180 value 271886.644356
## iter 190 value 271728.480786
## iter 200 value 271674.784278
## iter 210 value 271647.438909
## iter 220 value 271543.178571
## iter 230 value 271467.159108
## iter 240 value 271159.608346
## iter 250 value 270604.909284
## iter 260 value 270044.587473
## iter 270 value 269134.414776
## iter 280 value 268444.223709
## iter 290 value 267847.182251
## iter 300 value 267406.652354
## iter 310 value 265809.693904
## iter 320 value 265035.976227
## iter 330 value 264307.257612
## iter 340 value 263741.397973
## iter 350 value 263152.712739
## iter 360 value 262872.434935
## iter 370 value 262494.478165
## iter 380 value 262426.838080
## iter 390 value 262294.502479
## iter 400 value 262208.456947
## iter 410 value 262154.137444
## iter 420 value 262110.786612
## iter 430 value 262091.363872
## iter 440 value 262023.850209
## iter 450 value 261927.857728
## iter 460 value 261790.043092
## iter 470 value 261521.933658
## iter 480 value 260998.688455
## iter 490 value 260726.166296
## iter 500 value 260672.063412
## iter 510 value 260625.251491
## iter 520 value 260613.770995
## iter 530 value 260590.251756
## iter 540 value 260564.800520
## iter 550 value 260553.004796
## iter 560 value 260549.383780
## iter 570 value 260543.726248
## iter 580 value 260537.000162
## iter 590 value 260525.275463
## iter 600 value 260511.746080
## iter 610 value 260508.777972
## iter 620 value 260501.533481
## iter 630 value 260477.446995
## iter 640 value 260413.717448
## iter 650 value 260294.341683
## iter 660 value 260178.436190
## iter 670 value 260081.739469
## iter 680 value 259874.600292
## iter 690 value 258742.259501
## iter 700 value 257800.035553
## iter 710 value 255900.807734
## iter 720 value 254369.633423
## iter 730 value 253716.466826
## iter 740 value 253533.348217
## iter 750 value 253416.755194
## iter 760 value 253312.055063
## iter 770 value 253267.668993
## iter 780 value 253142.382091
## iter 790 value 253030.562479
## iter 800 value 252970.040910
## iter 810 value 252875.505893
## iter 820 value 252382.781943
## iter 830 value 252057.733618
## iter 840 value 251941.037117
## iter 850 value 251883.044350
## iter 860 value 251843.609013
## iter 870 value 251819.597636
## iter 880 value 251807.299010
## iter 890 value 251793.322447
## iter 900 value 251783.995851
## iter 910 value 251778.567205
## iter 920 value 251768.893643
## iter 930 value 251764.451200
## iter 940 value 251755.249670
## iter 950 value 251749.819010
## iter 960 value 251745.697598
## iter 970 value 251740.405396
## iter 980 value 251723.092102
## iter 990 value 251709.341060
## iter1000 value 251643.694988
## final value 251643.694988
## stopped after 1000 iterations
nn_train_preds <- predict(nn_model, newdata = taxi_train)
nn_test_preds <- predict(nn_model, newdata = taxi_test)
nn_train_metrics <- data.frame(
RMSE = RMSE(nn_train_preds, taxi_train$fare_amount),
MAE = MAE(nn_train_preds, taxi_train$fare_amount),
R2 = cor(nn_train_preds, taxi_train$fare_amount)^2
)
nn_test_metrics <- data.frame(
RMSE = RMSE(nn_test_preds, taxi_test$fare_amount),
MAE = MAE(nn_test_preds, taxi_test$fare_amount),
R2 = cor(nn_test_preds, taxi_test$fare_amount)^2
)
# Print metrics
print("Training Performance (Neural Network):")
## [1] "Training Performance (Neural Network):"
print(nn_train_metrics)
## RMSE MAE R2
## 1 2.353712 1.684833 0.7497582
print("Testing Performance (Neural Network):")
## [1] "Testing Performance (Neural Network):"
print(nn_test_metrics)
## RMSE MAE R2
## 1 2.330919 1.679057 0.7513461
# Combine metrics into a single table
nn_metrics_table <- data.frame(
Metric = c("RMSE", "MAE", "R-squared"),
Training = c(
RMSE(nn_train_preds, taxi_train$fare_amount),
MAE(nn_train_preds, taxi_train$fare_amount),
cor(nn_train_preds, taxi_train$fare_amount)^2
),
Testing = c(
RMSE(nn_test_preds, taxi_test$fare_amount),
MAE(nn_test_preds, taxi_test$fare_amount),
cor(nn_test_preds, taxi_test$fare_amount)^2
)
)
# Print the combined table
print("Neural Network Model Performance:")
## [1] "Neural Network Model Performance:"
print(nn_metrics_table)
## Metric Training Testing
## 1 RMSE 2.3537122 2.3309187
## 2 MAE 1.6848334 1.6790572
## 3 R-squared 0.7497582 0.7513461
# Optional: Display the table using knitr for better visualization
library(knitr)
kable(nn_metrics_table, caption = "Training and Testing Performance Metrics for Neural Network")
| Metric | Training | Testing |
|---|---|---|
| RMSE | 2.3537122 | 2.3309187 |
| MAE | 1.6848334 | 1.6790572 |
| R-squared | 0.7497582 | 0.7513461 |
Cross-Validation: This method manually splits the training data into k folds, trains on k-1 folds, evaluates on the held-out fold, and averages the performance metrics across the folds. The cross-validation results were consistent with the hold-out test performance (see the comparison after the CV summary below), so the model configuration was not changed.
set.seed(12345)
k <- 5
folds <- createFolds(taxi_train$fare_amount, k = k)
cv_results <- list() # reused here as a plain list of per-fold NN metrics (it previously held the xgb.cv object)
for (i in seq_along(folds)) {
# Split the training data into the current validation fold and the remaining training folds
train_indices <- unlist(folds[-i])
valid_indices <- unlist(folds[i])
train_data <- taxi_train[train_indices, ]
valid_data <- taxi_train[valid_indices, ]
# Fit the neural network on the training folds
nn_model <- nnet(
fare_amount ~ .,
data = train_data,
size = 10,
linout = TRUE,
maxit = 500,
decay = 0.01
)
# Evaluate on the held-out fold and store the metrics
valid_preds <- predict(nn_model, newdata = valid_data)
fold_metrics <- data.frame(
RMSE = RMSE(valid_preds, valid_data$fare_amount),
MAE = MAE(valid_preds, valid_data$fare_amount),
R2 = cor(valid_preds, valid_data$fare_amount)^2
)
cv_results[[i]] <- fold_metrics
}
## # weights: 91
## initial value 15516027.063558
## iter 10 value 731040.931842
## iter 20 value 705007.449877
## iter 30 value 680308.136238
## iter 40 value 622605.747105
## iter 50 value 569956.371891
## iter 60 value 496723.140313
## iter 70 value 447503.870239
## iter 80 value 409801.364671
## iter 90 value 373372.546569
## iter 100 value 343301.038635
## iter 110 value 326303.769421
## iter 120 value 298550.662428
## iter 130 value 267958.962544
## iter 140 value 244306.895889
## iter 150 value 233006.911806
## iter 160 value 231426.820914
## iter 170 value 229984.319059
## iter 180 value 228764.044536
## iter 190 value 228497.124280
## iter 200 value 228470.755506
## iter 210 value 228416.151054
## iter 220 value 228270.277308
## iter 230 value 227796.409754
## iter 240 value 226326.791044
## iter 250 value 224490.232153
## iter 260 value 221468.513961
## iter 270 value 219508.077738
## iter 280 value 218290.403038
## iter 290 value 217070.752760
## iter 300 value 215485.981431
## iter 310 value 214296.654342
## iter 320 value 213363.539916
## iter 330 value 212865.134242
## iter 340 value 212139.364296
## iter 350 value 211412.415595
## iter 360 value 209801.078139
## iter 370 value 208938.368350
## iter 380 value 208789.917665
## iter 390 value 208538.620785
## iter 400 value 208110.940550
## iter 410 value 207833.799289
## iter 420 value 207602.007817
## iter 430 value 207273.866931
## iter 440 value 206944.891320
## iter 450 value 206160.206910
## iter 460 value 205075.675360
## iter 470 value 204205.082758
## iter 480 value 203783.731226
## iter 490 value 203127.543452
## iter 500 value 202550.076358
## final value 202550.076358
## stopped after 500 iterations
## # weights: 91
## initial value 14886844.099549
## iter 10 value 717220.584908
## iter 20 value 688641.758605
## iter 30 value 585115.028958
## iter 40 value 506861.940934
## iter 50 value 470605.229818
## iter 60 value 442905.932178
## iter 70 value 430422.116421
## iter 80 value 417784.199119
## iter 90 value 396359.972903
## iter 100 value 372575.148483
## iter 110 value 339560.668849
## iter 120 value 285883.426984
## iter 130 value 270014.529272
## iter 140 value 260576.677467
## iter 150 value 254935.391595
## iter 160 value 250771.950311
## iter 170 value 247259.794840
## iter 180 value 245727.115164
## iter 190 value 245197.198485
## iter 200 value 244939.636262
## iter 210 value 244582.680991
## iter 220 value 244072.338607
## iter 230 value 243481.983380
## iter 240 value 241910.620373
## iter 250 value 239164.466548
## iter 260 value 234679.701173
## iter 270 value 223218.434658
## iter 280 value 217208.452180
## iter 290 value 215463.901314
## iter 300 value 213643.705714
## iter 310 value 212137.728576
## iter 320 value 211155.737346
## iter 330 value 210587.982617
## iter 340 value 209854.415224
## iter 350 value 209452.973056
## iter 360 value 209159.622128
## iter 370 value 208537.143844
## iter 380 value 208431.209452
## iter 390 value 208360.236895
## iter 400 value 208205.123191
## iter 410 value 208014.691337
## iter 420 value 207919.893162
## iter 430 value 207555.391922
## iter 440 value 207195.737251
## iter 450 value 206785.666157
## iter 460 value 206473.723371
## iter 470 value 206272.446716
## iter 480 value 206121.305534
## iter 490 value 205661.654869
## iter 500 value 205172.078179
## final value 205172.078179
## stopped after 500 iterations
## # weights: 91
## initial value 16222446.000903
## iter 10 value 722559.222668
## iter 20 value 705527.862871
## iter 30 value 599484.958313
## iter 40 value 489251.941378
## iter 50 value 430642.454133
## iter 60 value 407898.850466
## iter 70 value 383558.053279
## iter 80 value 344125.652888
## iter 90 value 313479.666418
## iter 100 value 290852.915665
## iter 110 value 282188.268984
## iter 120 value 277921.946558
## iter 130 value 268405.914068
## iter 140 value 258094.888902
## iter 150 value 242659.966675
## iter 160 value 231418.904488
## iter 170 value 225679.975893
## iter 180 value 222626.662352
## iter 190 value 221363.720926
## iter 200 value 220834.772241
## iter 210 value 220562.047615
## iter 220 value 220447.179785
## iter 230 value 220242.621165
## iter 240 value 219090.298012
## iter 250 value 216280.058599
## iter 260 value 213755.714894
## iter 270 value 212670.415631
## iter 280 value 211846.327047
## iter 290 value 211364.772046
## iter 300 value 210816.899528
## iter 310 value 210280.040030
## iter 320 value 210037.584295
## iter 330 value 209840.769278
## iter 340 value 209651.201310
## iter 350 value 209399.080411
## iter 360 value 209214.154008
## iter 370 value 209127.904916
## iter 380 value 209069.037679
## iter 390 value 209033.206773
## iter 400 value 208998.134871
## iter 410 value 208893.037097
## iter 420 value 208850.012297
## iter 430 value 208783.546606
## iter 440 value 208551.317307
## iter 450 value 208279.641756
## iter 460 value 207919.673591
## iter 470 value 207616.785251
## iter 480 value 207294.371626
## iter 490 value 207057.660751
## iter 500 value 206860.891403
## final value 206860.891403
## stopped after 500 iterations
## # weights: 91
## initial value 15103229.641452
## iter 10 value 718831.356935
## iter 20 value 692130.916422
## iter 30 value 549560.670309
## iter 40 value 452413.963964
## iter 50 value 429104.564441
## iter 60 value 411362.198758
## iter 70 value 391708.345433
## iter 80 value 351756.952346
## iter 90 value 308546.888909
## iter 100 value 288420.948724
## iter 110 value 278101.595402
## iter 120 value 272748.879947
## iter 130 value 266014.084515
## iter 140 value 262286.366345
## iter 150 value 254983.566488
## iter 160 value 247433.462378
## iter 170 value 243422.243322
## iter 180 value 241624.709486
## iter 190 value 240134.022231
## iter 200 value 239742.426182
## iter 210 value 239218.123142
## iter 220 value 238558.495475
## iter 230 value 237032.353890
## iter 240 value 233824.870712
## iter 250 value 229330.985342
## iter 260 value 222658.531040
## iter 270 value 217555.331326
## iter 280 value 214204.810317
## iter 290 value 212140.800240
## iter 300 value 210430.554884
## iter 310 value 209429.517486
## iter 320 value 208314.408741
## iter 330 value 207253.472825
## iter 340 value 206311.271775
## iter 350 value 205568.328002
## iter 360 value 204785.572410
## iter 370 value 204382.807171
## iter 380 value 204357.068735
## iter 390 value 204299.883156
## iter 400 value 204199.052814
## iter 410 value 204135.420111
## iter 420 value 204068.376150
## iter 430 value 204003.729383
## iter 440 value 203755.840997
## iter 450 value 203438.986984
## iter 460 value 203130.223855
## iter 470 value 202991.404756
## iter 480 value 202655.895464
## iter 490 value 202397.449660
## iter 500 value 202247.795767
## final value 202247.795767
## stopped after 500 iterations
## # weights: 91
## initial value 16059386.857834
## iter 10 value 733968.949887
## iter 20 value 705863.331929
## iter 30 value 621571.294208
## iter 40 value 576407.100518
## iter 50 value 514840.465565
## iter 60 value 485657.565779
## iter 70 value 467172.754033
## iter 80 value 433333.173168
## iter 90 value 371705.805748
## iter 100 value 339489.797144
## iter 110 value 292711.828747
## iter 120 value 271795.738680
## iter 130 value 262606.476781
## iter 140 value 255470.009061
## iter 150 value 251667.601247
## iter 160 value 247956.300298
## iter 170 value 242716.900041
## iter 180 value 235013.479372
## iter 190 value 228923.195192
## iter 200 value 226289.653966
## iter 210 value 224618.167479
## iter 220 value 222784.671994
## iter 230 value 221532.757506
## iter 240 value 218490.665877
## iter 250 value 216631.436212
## iter 260 value 214115.103544
## iter 270 value 212537.393424
## iter 280 value 210735.680606
## iter 290 value 209640.665735
## iter 300 value 207802.928752
## iter 310 value 206580.289376
## iter 320 value 205069.070194
## iter 330 value 204396.026876
## iter 340 value 203896.353884
## iter 350 value 203081.065272
## iter 360 value 202557.508445
## iter 370 value 202099.096917
## iter 380 value 202033.164291
## iter 390 value 201975.571016
## iter 400 value 201916.505242
## iter 410 value 201868.593735
## iter 420 value 201838.204564
## iter 430 value 201793.095100
## iter 440 value 201724.828987
## iter 450 value 201587.074293
## iter 460 value 201297.545166
## iter 470 value 200772.688385
## iter 480 value 200538.104748
## iter 490 value 200415.913000
## iter 500 value 200386.797145
## final value 200386.797145
## stopped after 500 iterations
cv_summary <- do.call(rbind, cv_results)
cv_means <- colMeans(cv_summary)
print("Cross-Validation Metrics (Neural Network):")
## [1] "Cross-Validation Metrics (Neural Network):"
print(cv_means)
## RMSE MAE R2
## 2.3755955 1.7082565 0.7450606
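To support the earlier point that the cross-validated results agree with the hold-out performance, the CV averages can be placed next to the test metrics. This is a small sketch using only objects already defined above (cv_means and nn_test_metrics).
# Sketch: compare the 5-fold CV averages with the hold-out test metrics
rbind(
"CV (5-fold)" = cv_means,
"Hold-out test" = unlist(nn_test_metrics[, c("RMSE", "MAE", "R2")])
)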
# Combine metrics into a single data frame
nn_metrics <- data.frame(
Metric = rep(c("RMSE", "MAE", "R-squared"), each = 2),
Value = c(
nn_train_metrics$RMSE, nn_test_metrics$RMSE,
nn_train_metrics$MAE, nn_test_metrics$MAE,
nn_train_metrics$R2, nn_test_metrics$R2
),
Dataset = rep(c("Training", "Testing"), times = 3)
)
# Load ggplot2 for plotting
library(ggplot2)
# Create the bar plot
metrics_plot <- ggplot(nn_metrics, aes(x = Metric, y = Value, fill = Dataset)) +
geom_bar(stat = "identity", position = "dodge", width = 0.7) +
theme_minimal() +
labs(
title = "Neural Network Metrics: Training vs Testing",
x = "Metric",
y = "Value",
fill = "Dataset"
) +
theme(legend.position = "top") +
scale_fill_manual(values = c("Training" = "steelblue", "Testing" = "darkorange"))
# Print the plot
print(metrics_plot)
#Model Comparison
# Metrics for Random Forest
rf_metrics <- data.frame(
Model = "Random Forest",
RMSE_Train = rf_train_metrics$RMSE,
RMSE_Test = rf_test_metrics$RMSE,
MAE_Train = rf_train_metrics$MAE,
MAE_Test = rf_test_metrics$MAE,
R2_Train = rf_train_metrics$R2,
R2_Test = rf_test_metrics$R2
)
# Metrics for Gradient Boosting
gb_metrics <- data.frame(
Model = "Gradient Boosting",
RMSE_Train = train_metrics$RMSE,
RMSE_Test = test_metrics$RMSE,
MAE_Train = train_metrics$MAE,
MAE_Test = test_metrics$MAE,
R2_Train = train_metrics$R2,
R2_Test = test_metrics$R2
)
# Metrics for Neural Network
nn_metrics <- data.frame(
Model = "Neural Network",
RMSE_Train = nn_train_metrics$RMSE,
RMSE_Test = nn_test_metrics$RMSE,
MAE_Train = nn_train_metrics$MAE,
MAE_Test = nn_test_metrics$MAE,
R2_Train = nn_train_metrics$R2,
R2_Test = nn_test_metrics$R2
)
# Combine metrics for all models
model_comparison <- rbind(rf_metrics, gb_metrics, nn_metrics)
# Display the comparison table
library(knitr)
kable(
model_comparison,
caption = "Model Comparison Metrics",
col.names = c("Model", "RMSE (Train)", "RMSE (Test)", "MAE (Train)", "MAE (Test)", "R2 (Train)", "R2 (Test)")
)
| Model | RMSE (Train) | RMSE (Test) | MAE (Train) | MAE (Test) | R2 (Train) | R2 (Test) |
|---|---|---|---|---|---|---|
| Random Forest | 1.081727 | 2.523936 | 0.7899513 | 1.882638 | 0.9600158 | 0.7160709 |
| Gradient Boosting | 2.097662 | 2.433837 | 1.5466593 | 1.787104 | 0.8040816 | 0.7297126 |
| Neural Network | 2.353712 | 2.330919 | 1.6848334 | 1.679057 | 0.7497582 | 0.7513461 |
library(reshape2)
##
## Attaching package: 'reshape2'
## The following objects are masked from 'package:data.table':
##
## dcast, melt
## The following object is masked from 'package:tidyr':
##
## smiths
# Melt the data for visualization
model_comparison_melted <- melt(model_comparison, id.vars = "Model")
# Plot RMSE (Train vs Test)
ggplot(subset(model_comparison_melted, variable %in% c("RMSE_Train", "RMSE_Test")),
aes(x = Model, y = value, fill = variable)) +
geom_bar(stat = "identity", position = "dodge") +
theme_minimal() +
labs(title = "RMSE Comparison Across Models",
x = "Model",
y = "RMSE",
fill = "Metric") +
theme(axis.text.x = element_text(angle = 45, hjust = 1))
# Plot MAE (Train vs Test)
ggplot(subset(model_comparison_melted, variable %in% c("MAE_Train", "MAE_Test")),
aes(x = Model, y = value, fill = variable)) +
geom_bar(stat = "identity", position = "dodge") +
theme_minimal() +
labs(title = "MAE Comparison Across Models",
x = "Model",
y = "MAE",
fill = "Metric") +
theme(axis.text.x = element_text(angle = 45, hjust = 1))
# Plot R2 (Train vs Test)
ggplot(subset(model_comparison_melted, variable %in% c("R2_Train", "R2_Test")),
aes(x = Model, y = value, fill = variable)) +
geom_bar(stat = "identity", position = "dodge") +
theme_minimal() +
labs(title = "R² Comparison Across Models",
x = "Model",
y = expression(R^2),
fill = "Metric") +
theme(axis.text.x = element_text(angle = 45, hjust = 1))
Random Forest: It performs best on the training set, with the lowest training RMSE (1.08) and the highest training R² (0.96), indicating that it captures the training data very well. On the test set, however, it has the highest RMSE (2.52) and the lowest R² (0.72), the largest train-test gap of the three models and a sign of mild overfitting.
Gradient Boosting: It has a slightly lower test RMSE (2.43) and a slightly higher test R² (0.73) than Random Forest, suggesting better generalization, though still not as good as the Neural Network.
Neural Network: It has the highest training RMSE (2.35) and the lowest training R² (0.75), but on the test set it outperforms the others, with the lowest RMSE (2.33) and the highest R² (0.75). This suggests it generalizes better than the other two models.
Conclusion: Based on test performance, the Neural Network is the best model for this dataset. Random Forest fits the training data very well but overfits slightly, as shown by the gap between its training and test RMSE (computed explicitly in the sketch below). Gradient Boosting is a reasonable compromise, performing consistently on both sets, but it is slightly outperformed by the Neural Network in generalization.
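To make the overfitting comparison concrete, the train-test RMSE gap can be computed per model from the comparison table built above. This is a small sketch; a larger gap indicates stronger overfitting.
# Sketch: train-test RMSE gap per model as a simple overfitting indicator
model_comparison$RMSE_Gap <- model_comparison$RMSE_Test - model_comparison$RMSE_Train
model_comparison[, c("Model", "RMSE_Train", "RMSE_Test", "RMSE_Gap")]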
#Neural Network: Actual vs Predicted Fares
ggplot(data = data.frame(
Actual = taxi_test$fare_amount,
Predicted = nn_test_preds
), aes(x = Actual, y = Predicted)) +
geom_point(alpha = 0.5, color = "blue") +
geom_abline(slope = 1, intercept = 0, color = "red", linetype = "dashed") +
theme_minimal() +
labs(
title = "Regression Plot: Actual vs Predicted Fare Amount",
x = "Actual Fare Amount",
y = "Predicted Fare Amount"
)
Saving the results (the export code below is kept commented out; enable it to write the final predictions to a CSV file).
#final_results <- data.frame(
# key = test_keys$key, # Test keys
# fare_amount = nn_test_preds # Predicted fare amounts
#)
# Save the results to a CSV file
#write.csv(final_results, "final_predictions.csv", row.names = FALSE)
# Print confirmation
#cat("Final predictions saved successfully as 'final_predictions.csv'.\n")
#View(final_results)
#Conclusions
The Neural Network was selected as the final model due to its superior generalization, reflected in the lowest test RMSE and highest test R² among the models compared. This suggests that it adapts best to the patterns in the dataset, making it the most suitable choice for taxi fare prediction.
The regression plot of actual vs. predicted fares shows that the Neural Network captures the overall trend effectively, with most predictions aligning closely with the actual values.