Libraries used in the project:
#Data manipulation and visualization
library(data.table)
library(tidyverse)
library(viridis)
library(lubridate)
library(geosphere)
library(ggforce)
library(geojson)
library(leaflet)
library(leaflet.extras)
#linear regression
library(MASS)
library(caTools)
library(car)
library(quantmod)
library(corrplot)
library(caret)
library(broom)
#regression tree
library(rpart)
library(rpart.plot)
library(rattle)
#bagged tree
library(e1071)
library(ipred)
#random Forest
library(randomForest)
#XGBoost
library(xgboost)
This project revolves around the prediction of taxi fares in New York City. Our aim is to build a model that forecasts the fare of a taxi ride (tolls included) given basic information such as the pickup and drop-off locations. The fee for a ride follows the taxi price list, but it can also vary with the route taken and traffic congestion. Our secondary aim was to compare how various machine learning models cope with this regression task. Such basic information could later be used by transportation companies to build more advanced fare-prediction models.
The dataset used in this project was made available in 2018 by Google for a Kaggle competition. It consists of 55 million rows and variables describing the pickup and drop-off coordinates, the exact pickup timestamp, the passenger count and the total fare for the ride.
Since the full dataset is too big to build models on, we will randomly select 100k rows. We will also limit the data to records from 2014 and 2015 (prices may have changed over the years, so for prediction we are only interested in the latest years available).
#function found on Stack Overflow -
# https://stackoverflow.com/questions/15532810/reading-40-gb-csv-file-into-r-using-bigmemory
fsample <- function(fname, n, seed, header = FALSE, ..., reader = read.csv) {
  set.seed(seed)
  con <- file(fname, open = "r")
  hdr <- if (header) {
    readLines(con, 1L)
  } else character()
  buf <- readLines(con, n)
  n_tot <- length(buf)
  repeat {
    txt <- readLines(con, n)
    if ((n_txt <- length(txt)) == 0L)
      break
    n_tot <- n_tot + n_txt
    n_keep <- rbinom(1, n_txt, n_txt / n_tot)
    if (n_keep == 0L)
      next
    keep <- sample(n_txt, n_keep)
    drop <- sample(n, n_keep)
    buf[drop] <- txt[keep]
  }
  reader(textConnection(c(hdr, buf)), header = header, ...)
}
taxis <- fsample(fname = "train.csv", n = 10^6, seed = 123, header = T, reader = read.csv)
closeAllConnections()
taxis <- taxis %>%
filter(substr(key, 1, 4) == "2015" | substr(key, 1, 4) == "2014" )
taxis <- taxis[sample(nrow(taxis), 100000), ]
# Save the sampled data so we don't have to repeat this computation - it's quite time-consuming
saveRDS(taxis, "taxis_small.rds")
taxis <- readRDS("taxis_small.rds")
Let’s inspect the data now.
There are 8 variables in total. ‘key’ is used as an ID of a ride. ‘pickup_longitude’ and ‘pickup_latitude’ describe the pickup coordinates; likewise for the drop-off longitude and latitude. The ‘passenger_count’ column gives the number of passengers in a ride, ‘pickup_datetime’ holds the exact pickup time and ‘fare_amount’ the fare.
kable(taxis[1:10, ])
| | key | fare_amount | pickup_datetime | pickup_longitude | pickup_latitude | dropoff_longitude | dropoff_latitude | passenger_count |
|---|---|---|---|---|---|---|---|---|
| 216153 | 2014-12-23 19:42:00.000000189 | 57.33 | 2014-12-23 19:42:00 UTC | -73.78225 | 40.64458 | -74.00796 | 40.74117 | 2 |
| 14300 | 2014-02-10 23:11:08.0000005 | 30.50 | 2014-02-10 23:11:08 UTC | 0.00000 | 0.00000 | 0.00000 | 0.00000 | 1 |
| 7837 | 2014-02-21 19:50:37.0000003 | 12.50 | 2014-02-21 19:50:37 UTC | -73.99249 | 40.72466 | -73.98779 | 40.74805 | 2 |
| 13532 | 2015-04-05 20:15:52.0000001 | 11.00 | 2015-04-05 20:15:52 UTC | -74.00418 | 40.73808 | -73.99505 | 40.75536 | 1 |
| 53063 | 2014-11-03 19:20:00.00000025 | 15.50 | 2014-11-03 19:20:00 UTC | -73.96572 | 40.75459 | -73.97586 | 40.79299 | 2 |
| 53704 | 2014-04-11 17:36:00.000000174 | 7.00 | 2014-04-11 17:36:00 UTC | -73.98159 | 40.77372 | -73.98055 | 40.78513 | 1 |
| 115101 | 2014-05-12 00:21:00.00000012 | 8.50 | 2014-05-12 00:21:00 UTC | -73.98630 | 40.74697 | -73.98251 | 40.76895 | 5 |
| 205808 | 2014-03-23 17:28:00.000000118 | 15.00 | 2014-03-23 17:28:00 UTC | -73.99020 | 40.75125 | -73.98813 | 40.72725 | 1 |
| 31546 | 2014-05-08 13:19:52.0000006 | 40.33 | 2014-05-08 13:19:52 UTC | -73.87106 | 40.77375 | -73.99479 | 40.75389 | 2 |
| 55931 | 2015-02-08 00:49:58.0000004 | 16.00 | 2015-02-08 00:49:58 UTC | -73.98228 | 40.76895 | -74.00237 | 40.73877 | 1 |
There seem to be records that are clearly wrong. It is not possible for the fare amount to be negative, and the minimum and maximum coordinates are not plausible either. We will focus on filtering out wrong data and outliers later.
#Confirmation that we indeed have 100 000 observations
nrow(taxis)
## [1] 100000
summary(taxis)
## key fare_amount pickup_datetime pickup_longitude
## Length:100000 Min. : -5.50 Length:100000 Min. :-81.23
## Class :character 1st Qu.: 6.50 Class :character 1st Qu.:-73.99
## Mode :character Median : 9.50 Mode :character Median :-73.98
## Mean : 12.93 Mean :-72.61
## 3rd Qu.: 14.50 3rd Qu.:-73.97
## Max. :362.33 Max. : 0.00
## pickup_latitude dropoff_longitude dropoff_latitude passenger_count
## Min. : 0.00 Min. :-171.80 Min. : 0.00 Min. :0.000
## 1st Qu.:40.74 1st Qu.: -73.99 1st Qu.:40.73 1st Qu.:1.000
## Median :40.75 Median : -73.98 Median :40.75 Median :1.000
## Mean :40.00 Mean : -72.60 Mean :39.99 Mean :1.692
## 3rd Qu.:40.77 3rd Qu.: -73.96 3rd Qu.:40.77 3rd Qu.:2.000
## Max. :41.40 Max. : 0.00 Max. :69.70 Max. :6.000
str(taxis)
## 'data.frame': 100000 obs. of 8 variables:
## $ key : chr "2014-12-23 19:42:00.000000189" "2014-02-10 23:11:08.0000005" "2014-02-21 19:50:37.0000003" "2015-04-05 20:15:52.0000001" ...
## $ fare_amount : num 57.3 30.5 12.5 11 15.5 ...
## $ pickup_datetime : chr "2014-12-23 19:42:00 UTC" "2014-02-10 23:11:08 UTC" "2014-02-21 19:50:37 UTC" "2015-04-05 20:15:52 UTC" ...
## $ pickup_longitude : num -73.8 0 -74 -74 -74 ...
## $ pickup_latitude : num 40.6 0 40.7 40.7 40.8 ...
## $ dropoff_longitude: num -74 0 -74 -74 -74 ...
## $ dropoff_latitude : num 40.7 0 40.7 40.8 40.8 ...
## $ passenger_count : int 2 1 2 1 2 1 5 1 2 1 ...
The average fare for a taxi ride was equal to 12.9$, while the median was 9.5$. Since the mean is noticeably higher than the median, we can assume that the distribution of the fee is skewed to the right. The standard deviation is as high as 11.4$.
mean(taxis$fare_amount)
## [1] 12.9287
median(taxis$fare_amount)
## [1] 9.5
sd(taxis$fare_amount)
## [1] 11.39675
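A quick check of the sample skewness (an illustrative snippet; the e1071 package is already loaded above) should return a clearly positive value, confirming the long right tail:
#positive skewness => long right tail (the mean is pulled above the median)
e1071::skewness(taxis$fare_amount)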
There are also no missing values or duplicate rows in the dataset.
#do we have NAs
colSums(is.na(taxis)) %>%
sort()
## key fare_amount pickup_datetime pickup_longitude
## 0 0 0 0
## pickup_latitude dropoff_longitude dropoff_latitude passenger_count
## 0 0 0 0
#checking for duplicate values
taxis[duplicated(taxis$key),]
## [1] key fare_amount pickup_datetime pickup_longitude
## [5] pickup_latitude dropoff_longitude dropoff_latitude passenger_count
## <0 rows> (or 0-length row.names)
Since there are no duplicate values, we will delete the ‘key’ column.
#not necessary
taxis$key <- NULL
The ‘pickup_datetime’ column also needs changing to a date format.
#changing pickup_datetime to date format
taxis$pickup_datetime <- ymd_hms(taxis$pickup_datetime)
class(taxis$pickup_datetime)
## [1] "POSIXct" "POSIXt"
Let’s inspect the fare amount further. Based on the distribution we can say that most of the values lie between just above 0 and 30$. There is a slight spike at around 50$. These observations could be rides to or from an airport, which carry a fixed fee of 52$. More on that later on.
There are also 23 records with a passenger count of 0. These could be rides that were canceled for some reason before the pickup while the driver had to wait for the client (there is an extra fee for that), or cases where the driver was asked to deliver a package. The average fee for a ride with 0 passengers is higher, but the sample size is very small. Nevertheless, we can assume that these observations are safe to ignore and we will therefore filter them out. What catches the eye is that there are more rides with 5 passengers than with 3 or 4. This could be related to how many seats are available in a typical taxi car.
#fare_amount
ggplot(taxis,
aes(x = fare_amount)) +
geom_histogram(fill = "lightblue",
bins = 100) +
theme_bw()
ggplot(taxis,
aes(x = fare_amount)) +
geom_histogram(fill = "lightblue",
bins = 100) +
theme_bw() +
facet_zoom(xlim = c(0, 60))
# boxplot(taxis$fare_amount)
min(taxis$fare_amount)
## [1] -5.5
max(taxis$fare_amount)
## [1] 362.33
#passengers
taxis %>%
group_by(passenger_count) %>%
summarise(number_of_rides = n(), avg_fee = mean(fare_amount))
## # A tibble: 7 x 3
## passenger_count number_of_rides avg_fee
## <int> <int> <dbl>
## 1 0 23 15.1
## 2 1 70407 12.7
## 3 2 13994 13.7
## 4 3 4317 13.4
## 5 4 2009 13.0
## 6 5 5662 13.2
## 7 6 3588 12.8
Therefore we will only keep fares higher than 2.5$ and lower than 200$. 2.5$ is the starting fare for a ride in NYC (based on information found on the NYC taxi website; note, however, that this is the fare table as of 2022. Prices may have changed since 2014-2015, but probably not dramatically. It may still be a good idea to look for the 2014 and 2015 price list. https://www1.nyc.gov/site/tlc/passengers/taxi-fare.page). Moreover, we will filter out rides with 0 passengers.
taxis <- taxis %>%
filter(fare_amount >= 2.5 & fare_amount < 200 & passenger_count > 0)
#Only 48 observations filtered out. Not many...
nrow(taxis)
## [1] 99952
Additionally, we will create new variables - the day of the week, month, hour and year. Moreover, we will add a column marking the special surcharge periods depending on the time of day - overnight rides and rides during rush hour are more expensive (based on the fare table on the NYC taxi website).
taxis <- taxis %>%
mutate(day = wday(pickup_datetime, label = T),
month_label = month(pickup_datetime, label = T),
hour_exact = hour(pickup_datetime),
year = year(pickup_datetime)) %>%
mutate(time_of_day = ifelse(hour_exact >=20 | hour_exact < 6, "overnight",
ifelse(hour_exact >=16 & hour_exact <20 & day %in% c("Mon", "Tue", "Wed", "Thu", "Fri"), "rush", "normal"))
) %>%
mutate(time_of_day = ifelse(day %in% c("Sat", "Sun"), "overnight", time_of_day))
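As a quick sanity check of the new variable (illustrative only), we can tabulate how the rides split across the three surcharge periods:
#illustrative check - number of rides per time-of-day level
taxis %>% count(time_of_day)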
As can be seen on the plot below, the number of rides rises gradually during the week, with most rides on Saturday, and drops significantly on Sunday.
taxis %>%
mutate(day = factor(day, levels = c("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"))) %>%
ggplot() +
geom_histogram(aes(x = day), stat = "count", fill = "#69b3a2", alpha = 0.6) +
theme_minimal() +
scale_y_continuous(labels = function(n)format(n, big.mark = " ")) +
theme(panel.grid.major.x = element_blank(),
panel.grid.minor.x = element_blank())
We can also notice that, perhaps somewhat surprisingly, the majority of rides happened during the overnight hours.
taxis %>%
ggplot(aes(x = round(fare_amount), fill = time_of_day)) +
geom_histogram(stat = "count", alpha = 0.6) +
theme_minimal() +
scale_y_continuous(labels = function(n)format(n, big.mark = " ")) +
theme(panel.grid.major.x = element_blank(),
panel.grid.minor.x = element_blank()) +
facet_zoom(xlim = c(0, 20))
We can safely assume that the fare will depend mostly on the distance of the ride. The standard charge per 1/5 mile in 2022 is equal to 0.50$. We will therefore calculate the Haversine (great-circle) distance between the pickup and drop-off coordinates.
The distHaversine() function outputs the distance in meters, hence we will divide it by 1000 to get kilometers. But first we need to filter out incorrect coordinates - the function won’t accept them.
taxis <- taxis %>%
filter(pickup_latitude > -90 & pickup_latitude < 90 &
dropoff_latitude > -90 & dropoff_latitude < 90 &
pickup_longitude > -180 & pickup_longitude < 180 &
dropoff_longitude > -180 & dropoff_longitude < 180) %>%
mutate(distance = (distHaversine(cbind(pickup_longitude, pickup_latitude), cbind(dropoff_longitude, dropoff_latitude)))/1000
)
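As a quick sanity check (illustrative only; the landmark coordinates below are approximate and not part of the dataset), distHaversine() indeed returns the great-circle distance in meters - two Midtown landmarks roughly one kilometer apart should give a value of about 1000:
#approximate coordinates of Times Square and the Empire State Building
distHaversine(c(-73.9855, 40.7580), c(-73.9857, 40.7484))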
Some of the records have very strange coordinates - we will only use those near New York City. We will select all the coordinates inside a bounding box spanned by the northernmost, southernmost, easternmost and westernmost points of NYC (plus a margin of error). This may not be perfectly accurate - it might be better to use GeoJSON data of the New York borders - but that would be harder to implement. Another issue is that some rides could start in New York City but end in one of the neighboring cities; this way we could filter out some correct observations, but it should only apply to a small fraction of them.
taxis <- taxis %>%
filter(pickup_longitude < -73 & pickup_longitude > -74.3
& pickup_latitude < 41.7 & pickup_latitude > 40.5
& dropoff_longitude < -73 & dropoff_longitude > -74.3
& dropoff_latitude < 41.7 & dropoff_latitude > 40.5
)
nrow(taxis)
## [1] 97918
#As many as 2,082 records have been omitted - about 2% of all observations
In the next step we will check the relationship between the fare and the distance. Based on the plot, it appears that the relationship is roughly linear. There also seem to be observations with a low distance but a very high fare - these could be rides that started and ended in the same location (for example, someone wanted to do something quickly in the city center, took a taxi and then returned with the same taxi). It is difficult to judge, hence we will keep these observations, especially since there aren’t many of them.
taxis %>%
ggplot(aes(x = distance, y = fare_amount)) +
geom_point() +
geom_smooth(method = "lm") +
theme_minimal()
## `geom_smooth()` using formula = 'y ~ x'
nrow(taxis %>%
filter(distance <2 & fare_amount > 40))
## [1] 167
# ggplot(taxis, aes(x = distance, y = fare_amount)) +
# geom_bin2d(bins = 100) +
# scale_fill_continuous(type = "viridis") +
# theme_bw() + xlim (0, 30)
We will also present the pickup locations on a map using the leaflet package. Most of the pickups happened in the very center of the city; however, there were also two spots to the south-east and north-east of the city with a considerable number of pickups - these are two airports. As stated on the NYC taxi website, rides to and from the airports can carry additional surcharges or even fixed fees. We will therefore add an additional variable stating whether the ride started or ended at an airport.
geojson = jsonlite::fromJSON("https://services5.arcgis.com/GfwWNkhOj9bNBqoJ/arcgis/rest/services/NYC_Borough_Boundary/FeatureServer/0/query?where=1=1&outFields=*&outSR=4326&f=pgeojson",
simplifyVector = FALSE)
#This one can take a while
leaflet(taxis) %>%
addTiles() %>%
addCircleMarkers(lat = ~pickup_latitude, lng = ~pickup_longitude, radius = 1, weight = 0.5) %>%
setView(lng = -74.00, lat = 40.7128, zoom = 10.5) %>%
addGeoJSON(geojson, weight = 0.5)
leaflet(taxis) %>%
addTiles() %>%
addHeatmap(lng= ~pickup_longitude, lat = ~pickup_latitude, intensity = 0.5,
blur = 20, max = 400, radius = 15, cellSize = 3) %>%
setView(lng = -74.00, lat = 40.7128, zoom = 12)
#airports' coordinates
jfk = c(-73.78222, 40.644166)
newark = c(-74.175, 40.69)
la_guardia = c(-73.87, 40.77)
#A pickup or drop-off at an airport is defined as one within a distance of 2 km of the above coordinates.
#A larger distance seems to result in too many rides not connected to the airport (based on the fare distribution).
taxis <- taxis %>%
mutate(airport = ifelse(
(distHaversine(cbind(pickup_longitude, pickup_latitude), jfk))<2000 | (distHaversine(cbind(dropoff_longitude, dropoff_latitude), jfk))<2000,
"JFK", "Not an airport"
)) %>%
mutate(airport = ifelse(
(distHaversine(cbind(pickup_longitude, pickup_latitude), newark))<2000 | (distHaversine(cbind(dropoff_longitude, dropoff_latitude), newark))<2000,
"Newark", airport
)) %>%
mutate(airport = ifelse(
(distHaversine(cbind(pickup_longitude, pickup_latitude), la_guardia))<2000 | (distHaversine(cbind(dropoff_longitude, dropoff_latitude), la_guardia))<2000,
"La_Guardia", airport
))
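The three nearly identical mutate() calls above could also be collapsed with a small helper function - a sketch only (the near_airport() helper below is hypothetical and not used elsewhere in this project):
#hypothetical helper: TRUE if the pickup or drop-off lies within `radius` meters of `coords`
near_airport <- function(df, coords, radius = 2000) {
  distHaversine(cbind(df$pickup_longitude, df$pickup_latitude), coords) < radius |
    distHaversine(cbind(df$dropoff_longitude, df$dropoff_latitude), coords) < radius
}
#e.g. taxis$airport[near_airport(taxis, jfk)] <- "JFK"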
#Airports - pickup and drop-off locations (ride could start at an airport and end in a city center and vice versa)
taxis %>%
filter(airport %in% c("JFK", "Newark", "La_Guardia")) %>%
leaflet() %>%
addTiles() %>%
addCircleMarkers(lat = ~pickup_latitude, lng = ~pickup_longitude, radius = 1, weight = 0.5) %>%
addCircleMarkers(lat = ~dropoff_latitude, lng = ~dropoff_longitude, radius = 1, weight = 0.5, color = "red") %>%
setView(lng = -74.00, lat = 40.7128, zoom = 10.5) %>%
addGeoJSON(geojson, weight = 0.5)
taxis %>%
filter(airport %in% c("JFK", "Newark", "La_Guardia")) %>%
ggplot(aes(x = fare_amount)) +
geom_histogram(fill = "lightblue",
bins = 100) +
ggtitle("Distriution of fare of rides labeled as to or from an airport") +
theme_bw()
# saveRDS(taxis, "taxis_cleaned.rds")
taxis <- readRDS("taxis_cleaned.rds")
Firstly, we will divide the dataset into training and testing sets in a proportion of 0.75 and 0.25 respectively.
set.seed(123)
training_obs <- createDataPartition(taxis$fare_amount,
p = 0.75,
list = FALSE)
taxis_train <- taxis[training_obs,]
taxis_test <- taxis[-training_obs,]
First of all, as a benchmark for the other models, we are going to use a simple linear regression model. In order to select the variables for the model, we are going to use stepwise feature selection with the Akaike Information Criterion (AIC) as the selection measure.
taxis_linear_full <- lm(fare_amount~.,
data = taxis_train)
taxis_linear_step <- stepAIC(taxis_linear_full, direction = "both",
trace = FALSE)
summary(taxis_linear_step)
##
## Call:
## lm(formula = fare_amount ~ pickup_datetime + pickup_longitude +
## dropoff_longitude + dropoff_latitude + passenger_count +
## day + month_label + hour_exact + time_of_day + distance +
## airport, data = taxis_train)
##
## Residuals:
## Min 1Q Median 3Q Max
## -58.638 -1.785 -0.628 0.963 182.954
##
## Coefficients:
## Estimate Std. Error t value
## (Intercept) 960.594442682802 82.309905727980 11.670
## pickup_datetime 0.000000006953 0.000000001302 5.341
## pickup_longitude 4.366915785098 0.687260476782 6.354
## dropoff_longitude 1.423254646464 0.586905388377 2.425
## dropoff_latitude -13.147953847569 0.584894395887 -22.479
## passenger_count 0.027860411298 0.012441911649 2.239
## day.L 0.433134242843 0.044814655480 9.665
## day.Q 0.317829594737 0.055999295918 5.676
## day.C -0.167995728411 0.044913563899 -3.740
## day^4 0.500808790004 0.047050057107 10.644
## day^5 -0.110193960457 0.044876301453 -2.456
## day^6 0.066295030970 0.044570083300 1.487
## month_label.L 0.340158413097 0.063147497122 5.387
## month_label.Q 0.104632760296 0.062706134781 1.669
## month_label.C 0.303063023229 0.063112057550 4.802
## month_label^4 -0.120512421066 0.062037024296 -1.943
## month_label^5 -0.109217086379 0.063059174460 -1.732
## month_label^6 0.024733350584 0.062346589956 0.397
## month_label^7 0.168242442674 0.062395625988 2.696
## month_label^8 0.034748336004 0.061786573536 0.562
## month_label^9 -0.154846124874 0.062199276978 -2.490
## month_label^10 -0.068425598886 0.062834288917 -1.089
## month_label^11 -0.008574341985 0.063494263407 -0.135
## hour_exact 0.039531357451 0.002856561015 13.839
## time_of_dayovernight -1.619848574308 0.049056537224 -33.020
## time_of_dayrush -0.489145639253 0.055965883986 -8.740
## distance 2.381602179453 0.006683247695 356.354
## airportLa_Guardia 5.267359196104 0.163458383247 32.224
## airportNewark 33.391781766182 0.474353296184 70.394
## airportNot an airport -1.532476331660 0.192035271819 -7.980
## Pr(>|t|)
## (Intercept) < 0.0000000000000002 ***
## pickup_datetime 0.00000009289673432 ***
## pickup_longitude 0.00000000021088644 ***
## dropoff_longitude 0.015310 *
## dropoff_latitude < 0.0000000000000002 ***
## passenger_count 0.025143 *
## day.L < 0.0000000000000002 ***
## day.Q 0.00000001387226523 ***
## day.C 0.000184 ***
## day^4 < 0.0000000000000002 ***
## day^5 0.014071 *
## day^6 0.136905
## month_label.L 0.00000007197176124 ***
## month_label.Q 0.095197 .
## month_label.C 0.00000157411082989 ***
## month_label^4 0.052070 .
## month_label^5 0.083282 .
## month_label^6 0.691584
## month_label^7 0.007011 **
## month_label^8 0.573850
## month_label^9 0.012794 *
## month_label^10 0.276164
## month_label^11 0.892580
## hour_exact < 0.0000000000000002 ***
## time_of_dayovernight < 0.0000000000000002 ***
## time_of_dayrush < 0.0000000000000002 ***
## distance < 0.0000000000000002 ***
## airportLa_Guardia < 0.0000000000000002 ***
## airportNewark < 0.0000000000000002 ***
## airportNot an airport 0.00000000000000148 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 4.566 on 73410 degrees of freedom
## Multiple R-squared: 0.8293, Adjusted R-squared: 0.8292
## F-statistic: 1.229e+04 on 29 and 73410 DF, p-value: < 0.00000000000000022
#function getRegressionMetrics was created by Paweł Sakowski for his Machine Learning 2 classes.
source("C:/Projects/taxi-fare-prediction-main/source/getRegressionMetrics.R")
model_linear_results_train <- getRegressionMetrics(real = taxis_train$fare_amount,
predicted = predict(taxis_linear_step, taxis_train))
model_linear_results_test <- getRegressionMetrics(real = taxis_test$fare_amount,
predicted = predict(taxis_linear_step, taxis_test))
model_linear_results_train
## MSE RMSE MAE MedAE MSLE R2
## 1 20.83955 4.565036 2.29565 1.52815 0.06435022 0.8292523
model_linear_results_test
## MSE RMSE MAE MedAE MSLE R2
## 1 21.11537 4.595146 2.321552 1.535537 0.06446535 0.8277197
Root mean squared error (RMSE) on the test data was equal to 4.59$ and R^2 to 0.83.
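For readers without access to that helper file, a minimal sketch of a function returning the same set of metrics could look as follows (a rough re-implementation of our own; the original by Paweł Sakowski may differ in details):
#rough sketch of a regression-metrics helper (assumed, not the original)
getRegressionMetrics_sketch <- function(real, predicted) {
  resid <- real - predicted
  data.frame(MSE   = mean(resid^2),
             RMSE  = sqrt(mean(resid^2)),
             MAE   = mean(abs(resid)),
             MedAE = median(abs(resid)),
             MSLE  = mean((log1p(real) - log1p(predicted))^2),
             R2    = 1 - sum(resid^2) / sum((real - mean(real))^2))
}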
We will also inspect the collinearity of the variables with the variance inflation factor (VIF). Following Fox & Monette (1992), we will look at GVIF^(1/(2*Df)) and apply to it a standard VIF rule of thumb - that values above 5 indicate collinearity. On this basis we will assume that there is no such problem in our model, which may be a little surprising given that some of the variables, like day, month or hour, were extracted from the pickup_datetime variable.
vif(taxis_linear_step)
## GVIF Df GVIF^(1/(2*Df))
## pickup_datetime 1.118902 1 1.057782
## pickup_longitude 2.151331 1 1.466742
## dropoff_longitude 1.486134 1 1.219071
## dropoff_latitude 1.212819 1 1.101281
## passenger_count 1.002183 1 1.001091
## day 1.714028 6 1.045927
## month_label 1.135958 11 1.005811
## hour_exact 1.215933 1 1.102694
## time_of_day 1.978713 2 1.186030
## distance 2.242201 1 1.497398
## airport 4.331455 3 1.276749
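A small programmatic check (a sketch) can flag any term exceeding that threshold - given the values above it should return an empty result (column 3 is the GVIF^(1/(2*Df)) column shown above):
#flag terms with GVIF^(1/(2*Df)) above the rule-of-thumb threshold of 5
v <- vif(taxis_linear_step)
v[v[, 3] > 5, , drop = FALSE]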
As we showed earlier, the distribution of the dependent variable - fare_amount - is right-skewed, therefore we will try to normalize it with a log transformation.
#Log transformation
skewness(taxis_train$fare_amount)
## [1] 3.141872
taxis_train %>%
ggplot(aes(x = log(fare_amount))) +
geom_histogram(fill = "lightblue",
bins = 100) +
geom_density(alpha = .2, fill = "#FF6666") +
theme_minimal()
taxis_linear_log <- lm(log(fare_amount)~.,
data = taxis_train)
taxis_linear_step_log <- stepAIC(taxis_linear_log, direction = "both",
trace = FALSE)
summary(taxis_linear_step_log)
##
## Call:
## lm(formula = log(fare_amount) ~ pickup_longitude + pickup_latitude +
## dropoff_longitude + dropoff_latitude + passenger_count +
## day + month_label + hour_exact + year + time_of_day + distance +
## airport, data = taxis_train)
##
## Residuals:
## Min 1Q Median 3Q Max
## -5.7085 -0.1908 0.0056 0.1864 4.0297
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 74.6928101 8.9595260 8.337 < 0.0000000000000002 ***
## pickup_longitude 0.1717351 0.0502829 3.415 0.000637 ***
## pickup_latitude -0.7058026 0.0551073 -12.808 < 0.0000000000000002 ***
## dropoff_longitude 0.1770871 0.0434273 4.078 0.0000455155 ***
## dropoff_latitude -1.1293039 0.0438583 -25.749 < 0.0000000000000002 ***
## passenger_count 0.0031838 0.0008977 3.547 0.000390 ***
## day.L 0.0359319 0.0032333 11.113 < 0.0000000000000002 ***
## day.Q 0.0203488 0.0040423 5.034 0.0000004815 ***
## day.C -0.0101024 0.0032404 -3.118 0.001824 **
## day^4 0.0376558 0.0033953 11.091 < 0.0000000000000002 ***
## day^5 -0.0100763 0.0032377 -3.112 0.001858 **
## day^6 0.0042003 0.0032156 1.306 0.191487
## month_label.L 0.0484429 0.0049514 9.784 < 0.0000000000000002 ***
## month_label.Q 0.0012182 0.0045242 0.269 0.787734
## month_label.C 0.0260488 0.0045560 5.718 0.0000000109 ***
## month_label^4 -0.0090334 0.0044759 -2.018 0.043574 *
## month_label^5 -0.0105810 0.0045499 -2.326 0.020045 *
## month_label^6 -0.0040844 0.0044984 -0.908 0.363893
## month_label^7 0.0084752 0.0045016 1.883 0.059743 .
## month_label^8 -0.0010883 0.0044578 -0.244 0.807124
## month_label^9 -0.0136886 0.0044879 -3.050 0.002288 **
## month_label^10 -0.0097880 0.0045334 -2.159 0.030847 *
## month_label^11 -0.0012079 0.0045812 -0.264 0.792034
## hour_exact 0.0034029 0.0002064 16.484 < 0.0000000000000002 ***
## year 0.0132916 0.0029671 4.480 0.0000074864 ***
## time_of_dayovernight -0.1187080 0.0035587 -33.357 < 0.0000000000000002 ***
## time_of_dayrush -0.0371935 0.0040398 -9.207 < 0.0000000000000002 ***
## distance 0.1557028 0.0004833 322.188 < 0.0000000000000002 ***
## airportLa_Guardia 1.0979388 0.0127174 86.334 < 0.0000000000000002 ***
## airportNewark 0.8534392 0.0351048 24.311 < 0.0000000000000002 ***
## airportNot an airport 0.9546220 0.0147445 64.744 < 0.0000000000000002 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.3294 on 73409 degrees of freedom
## Multiple R-squared: 0.7226, Adjusted R-squared: 0.7224
## F-statistic: 6373 on 30 and 73409 DF, p-value: < 0.00000000000000022
getRegressionMetrics(real = log(taxis_test$fare_amount),
predicted = predict(taxis_linear_step_log, taxis_test))
## MSE RMSE MAE MedAE MSLE R2
## 1 0.107622 0.3280579 0.2386273 0.1868871 0.01025206 0.7246729
The adjusted R^2 for the model with the logarithm of fare_amount is, however, much lower than for the previous model (the RMSE is not directly comparable because it is measured on the log scale). We will thus use the model without the transformation. The linear model could be further extended with the use of a Yeo-Johnson transformation instead of the log transformation. We will, however, focus more on the other models and use linear regression mainly as a benchmark. We have also omitted the full model diagnostics, which would be necessary if the model were to be used to make real-life fare predictions. Based on the plot below, for example, we can suspect that heteroskedasticity is present.
lares::mplot_lineal(tag = taxis_test$fare_amount,
score = predict(taxis_linear_step, taxis_test),
subtitle = "Taxi fare regression model",
model_name = "Stepwise feature selection")
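As mentioned above, a Yeo-Johnson transformation of the target could also be tried instead of the logarithm; a minimal sketch using caret's preProcess() (not evaluated here) might look like this:
#sketch only: Yeo-Johnson transform of fare_amount via caret, then refit the linear model
yj <- preProcess(taxis_train[, "fare_amount", drop = FALSE], method = "YeoJohnson")
taxis_train_yj <- taxis_train
taxis_train_yj$fare_amount <- predict(yj, taxis_train[, "fare_amount", drop = FALSE])$fare_amount
taxis_linear_yj <- lm(fare_amount ~ ., data = taxis_train_yj)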
As the next model for the prediction we will use a regression tree. At first we will use the rpart() function from the rpart package.
#Decision tree
#check whether method = anova is indeed the right choice
set.seed(123)
taxis_tree_rpart <- rpart(fare_amount~.,
data = taxis_train,
method = "anova")
taxis_tree_rpart
## n= 73440
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 73440 8963266.00 12.894630
## 2) distance< 8.268036 67316 2525014.00 10.375750
## 4) distance< 3.359515 50979 1057048.00 8.281781
## 8) distance< 1.748563 29333 693577.70 6.819109 *
## 9) distance>=1.748563 21646 215674.40 10.263880 *
## 5) distance>=3.359515 16337 546928.50 16.909890
## 10) distance< 5.472393 10807 216812.50 14.858090 *
## 11) distance>=5.472393 5530 195707.80 20.919640 *
## 3) distance>=8.268036 6124 1316349.00 40.582590
## 6) distance< 16.06503 4294 376764.70 33.809500 *
## 7) distance>=16.06503 1830 280382.20 56.475270
## 14) dropoff_longitude>=-74.04027 1719 149519.00 54.655040 *
## 15) dropoff_longitude< -74.04027 111 36965.17 84.664230 *
rpart.plot(taxis_tree_rpart)
plotcp(taxis_tree_rpart)
taxis_tree_rpart$cptable
## CP nsplit rel error xerror xstd
## 1 0.57143267 0 1.0000000 1.0000356 0.015539960
## 2 0.10275683 1 0.4285673 0.4299987 0.010238835
## 3 0.07354490 2 0.3258105 0.3263641 0.010300142
## 4 0.01648912 3 0.2522656 0.2545888 0.009738799
## 5 0.01499544 4 0.2357765 0.2358555 0.009843046
## 6 0.01047587 5 0.2207810 0.2227201 0.009876202
## 7 0.01000000 6 0.2103052 0.2143213 0.009695108
decision_tree_results_rpart_train <- getRegressionMetrics(real = taxis_train$fare_amount,
predicted = predict(taxis_tree_rpart, taxis_train))
decision_tree_results_rpart_test <- getRegressionMetrics(real = taxis_test$fare_amount,
predicted = predict(taxis_tree_rpart, taxis_test))
decision_tree_results_rpart_train
## MSE RMSE MAE MedAE MSLE R2
## 1 25.6675 5.06631 2.745766 1.819109 0.07826417 0.7896948
decision_tree_results_rpart_test
## MSE RMSE MAE MedAE MSLE R2
## 1 26.49949 5.147765 2.786504 1.819109 0.07972453 0.7837907
For this model, the optimal complexity parameter was selected at 0.01, which resulted in an only moderately complex tree with just 6 splits. A complexity parameter of 0.01 is the default lower bound in rpart, so deeper trees are not considered by default. We will therefore try to tune it with the caret package. The accuracy is much lower than for the linear regression - RMSE on the test data is equal to 5.15 and R^2 to 0.78.
tc <- trainControl(method = "cv", number = 10)
cp.grid <- expand.grid(cp = seq(0, 0.03, 0.001))
set.seed(123)
taxis_tree <- train(fare_amount~.,
data = taxis_train,
method = "rpart",
trControl = tc,
tuneGrid = cp.grid)
taxis_tree
## CART
##
## 73440 samples
## 13 predictor
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 66096, 66095, 66095, 66097, 66097, 66097, ...
## Resampling results across tuning parameters:
##
## cp RMSE Rsquared MAE
## 0.000 4.656515 0.8243952 2.400041
## 0.001 4.559198 0.8296517 2.413487
## 0.002 4.603924 0.8261317 2.476781
## 0.003 4.747481 0.8150316 2.606898
## 0.004 4.840488 0.8077834 2.672070
## 0.005 4.985847 0.7962043 2.689315
## 0.006 4.980011 0.7966745 2.690651
## 0.007 4.975160 0.7970554 2.691409
## 0.008 5.023356 0.7930449 2.720297
## 0.009 5.071772 0.7890985 2.750729
## 0.010 5.071772 0.7890985 2.750729
## 0.011 5.189889 0.7792065 2.762830
## 0.012 5.192004 0.7790481 2.762227
## 0.013 5.192004 0.7790481 2.762227
## 0.014 5.192004 0.7790481 2.762227
## 0.015 5.212626 0.7772135 2.780427
## 0.016 5.338745 0.7663561 2.919246
## 0.017 5.556893 0.7469355 3.305430
## 0.018 5.556893 0.7469355 3.305430
## 0.019 5.556893 0.7469355 3.305430
## 0.020 5.556893 0.7469355 3.305430
## 0.021 5.556893 0.7469355 3.305430
## 0.022 5.556893 0.7469355 3.305430
## 0.023 5.556893 0.7469355 3.305430
## 0.024 5.556893 0.7469355 3.305430
## 0.025 5.556893 0.7469355 3.305430
## 0.026 5.556893 0.7469355 3.305430
## 0.027 5.556893 0.7469355 3.305430
## 0.028 5.556893 0.7469355 3.305430
## 0.029 5.556893 0.7469355 3.305430
## 0.030 5.556893 0.7469355 3.305430
##
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was cp = 0.001.
The final model used a cp value of 0.001, which resulted in quite a complex tree with 24 splits and 25 terminal nodes. Distance turned out to be the most important variable in explaining the fare amount, followed by the drop-off longitude and the indicator that the ride is not to or from an airport. For some reason, only one of the month contrasts (month_label^11) had an importance above 0.
fancyRpartPlot(taxis_tree$finalModel, cex = 0.5)
varImp(taxis_tree)
## rpart variable importance
##
## only 20 most important variables shown (out of 33)
##
## Overall
## distance 100.000
## dropoff_longitude 45.043
## airportNot an airport 34.134
## pickup_longitude 31.443
## dropoff_latitude 28.329
## pickup_latitude 19.997
## airportNewark 13.994
## airportLa_Guardia 10.307
## time_of_dayovernight 6.017
## month_label^11 4.988
## hour_exact 3.035
## pickup_datetime 2.336
## month_label.C 0.000
## `month_label^4` 0.000
## day.C 0.000
## `day^4` 0.000
## `month_label^9` 0.000
## day.L 0.000
## `month_label^8` 0.000
## passenger_count 0.000
model_tree_results_train <- getRegressionMetrics(real = taxis_train$fare_amount,
predicted = predict(taxis_tree, taxis_train))
model_tree_results_test <- getRegressionMetrics(real = taxis_test$fare_amount,
predicted = predict(taxis_tree, taxis_test))
model_tree_results_train
## MSE RMSE MAE MedAE MSLE R2
## 1 18.94488 4.352572 2.3795 1.57647 0.06554675 0.8447762
model_tree_results_test
## MSE RMSE MAE MedAE MSLE R2
## 1 19.7531 4.444446 2.42287 1.57647 0.06753319 0.8388345
The results of the regression tree with a complexity parameter of 0.001 are much better than those of the tree with the default cp. The RMSE is better than the linear regression’s too - it is equal to 4.44 (versus 4.59 for the linear model). The model seems slightly overfitted, however, with an RMSE on the training data of 4.35.
The next model used in the analysis is the bagged tree - a tree ensemble based on bootstrap aggregation. Bootstrap aggregation, or bagging, for a regression tree is a method in which a sample of the data is drawn with replacement (which means that some observations are left out completely), a tree is fit to each sample, and the predictions are averaged over the trees built in this way. To construct a bagged tree we will use the bagging() function from the ipred package.
We will use 100 bootstrap replications and a complexity parameter of 0.001 (the optimal cp found for the single tree in the previous model). It would be interesting to see, however, how the results would look if the cp parameter were tuned within the bagging itself - a rough sketch of such a check follows the results below.
set.seed(123)
bag_tree <- bagging(
formula = fare_amount ~ .,
data = taxis_train,
nbagg = 100,
coob = TRUE,
control = rpart.control(cp = 0.001)
)
bag_tree
#Saving the model, since it takes quite a long time to compute
#The rds file with the model weighs as much as 1.5 GB... Therefore we are also going
#to save the regression metrics, so there is no need to load the model every time
# saveRDS(bag_tree, "bag_tree.rds")
bag_tree <- readRDS("bag_tree.rds")
model_bagg_results_train <- getRegressionMetrics(real = taxis_train$fare_amount,
predicted = predict(bag_tree, taxis_train))
model_bagg_results_test <-getRegressionMetrics(real = taxis_test$fare_amount,
predicted = predict(bag_tree, taxis_test))
saveRDS(model_bagg_results_train, "model_bagg_results_train.rds")
saveRDS(model_bagg_results_test, "model_bagg_results_test.rds")
model_bagg_results_train <- readRDS("model_bagg_results_train.rds")
model_bagg_results_test <- readRDS("model_bagg_results_test.rds")
model_bagg_results_train
## MSE RMSE MAE MedAE MSLE R2
## 1 17.67833 4.20456 2.322245 1.536077 0.06252857 0.8551537
model_bagg_results_test
## MSE RMSE MAE MedAE MSLE R2
## 1 18.8018 4.336104 2.369062 1.550254 0.06411911 0.8465961
Here the results are noticeably better than for the previous models. RMSE on the test data is equal to 4.33 and R^2 to 0.85. The test RMSE is higher than on the training data, which may indicate that the model was slightly overfitted; however, the difference is not large.
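As noted earlier, it could also be checked whether a different cp works better inside the bagging itself. A rough sketch of such a check using the out-of-bag error (with a smaller nbagg to keep it cheap; we assume here that the err component of the fitted object holds the OOB error for regression):
#indicative only - compare a few cp values via the out-of-bag error
for (cp_val in c(0.0005, 0.001, 0.005)) {
  set.seed(123)
  fit <- bagging(fare_amount ~ ., data = taxis_train, nbagg = 25, coob = TRUE,
                 control = rpart.control(cp = cp_val))
  cat("cp =", cp_val, "OOB error =", fit$err, "\n")
}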
# VI <- data.frame(var=names(taxis_train[,-1]), imp=varImp(bag_tree))
#
# variables_importance <- data.frame(variable = row.names(varImp(bag_tree)),
# importance = varImp(bag_tree)$Overall)
#
# ggplot(data = variables_importance, aes(x = reorder(variable, -importance), y = importance)) +
# geom_bar(stat = "identity", fill = "steelblue2") +
# labs(x = "Variable") +
# theme_minimal() +
# theme(panel.grid.major.x = element_blank(),
# panel.grid.minor.x = element_blank(),
# axis.text.x = element_text(angle = 70, hjust=1, size = 11))
In the next part we are going to build a random forest model. A random forest uses multiple decision trees with additional randomness in both the variables considered at each split and the sample used to grow each tree. The mean prediction of the individual trees is then returned.
Firstly, let’s use the randomForest package to build the model.
set.seed(123)
taxis_forest <- randomForest(fare_amount~., data = taxis_train, importance = T)
saveRDS(taxis_forest, "taxis_forest.rds")
The random forest was built with 500 trees and 4 variables tried at each split. However, the error seems to decline only slightly beyond about 100 trees.
taxis_forest <- readRDS("taxis_forest.rds")
print(taxis_forest)
##
## Call:
## randomForest(formula = fare_amount ~ ., data = train, importance = T)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 4
##
## Mean of squared residuals: 18.08911
## % Var explained: 85.38
plot(taxis_forest)
model_forest_results_train <- getRegressionMetrics(real = taxis_train$fare_amount,
predicted = predict(taxis_forest, taxis_train))
model_forest_results_test <- getRegressionMetrics(real = taxis_test$fare_amount,
predicted = predict(taxis_forest, taxis_test))
model_forest_results_train
## MSE RMSE MAE MedAE MSLE R2
## 1 17.93963 4.23552 2.061616 1.265373 0.05553159 0.8530127
model_forest_results_test
## MSE RMSE MAE MedAE MSLE R2
## 1 18.88592 4.345793 2.090309 1.264889 0.05627187 0.8459098
RMSE for the random forest was equal to 4.35 and R^2 to 0.85.
The previous model used 4 variables chosen at random at each split. We will try to optimize this parameter (mtry) using the caret package. To reduce the amount of computation, the number of trees will be limited to 100, as more trees resulted in only a minor improvement in error.
parameters <- expand.grid(mtry = 2:9)
ctrl_cv5 <- trainControl(method = "cv",
number = 5)
set.seed(123)
taxis_rf_optimized <-
train(fare_amount~.,
data = taxis_train,
method = "rf",
ntree = 100,
# nodesize = 100,
tuneGrid = parameters,
trControl = ctrl_cv5,
importance = TRUE)
saveRDS(object = taxis_rf_optimized, "taxis_rf_optimized.rds")
Let’s see how this model performed.
taxis_rf_optimized <- readRDS("C:/Projects/taxi-fare-prediction-main/taxis_rf_optimized.rds")
getRegressionMetrics(real = taxis_train$fare_amount,
predicted = predict(taxis_rf_optimized, taxis_train))
## MSE RMSE MAE MedAE MSLE R2
## 1 18.72568 4.327318 2.152719 1.359189 0.05876349 0.8465722
getRegressionMetrics(real = taxis_test$fare_amount,
predicted = predict(taxis_rf_optimized, taxis_test))
## MSE RMSE MAE MedAE MSLE R2
## 1 19.72477 4.441257 2.172562 1.356417 0.05901403 0.8390656
Performance was noticeably worse than for the model with default parameters, which is quite surprising. RMSE is 4.44 and, in fact, the same as for the single regression tree model. The reason could perhaps be the limited number of trees (100); a possible follow-up check is sketched below. We will therefore keep the default random forest model.
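One possible follow-up (a sketch, not run here) would be to refit the best mtry found by caret with the full 500 trees and compare it against the default model:
#illustrative follow-up: best mtry from the caret search, but with 500 trees
best_mtry <- taxis_rf_optimized$bestTune$mtry
set.seed(123)
taxis_rf_check <- randomForest(fare_amount ~ ., data = taxis_train,
                               mtry = best_mtry, ntree = 500)
getRegressionMetrics(real = taxis_test$fare_amount,
                     predicted = predict(taxis_rf_check, taxis_test))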
The variable importance seems to coincide with the previous results. A notable difference, however, is the lower importance of the airport information.
varImpPlot(taxis_forest)
variables_importance_rf <- data.frame(variable = row.names(varImp(taxis_forest)),
importance = varImp(taxis_forest)$Overall)
ggplot(data = variables_importance_rf, aes(x = reorder(variable, -importance), y = importance)) +
geom_bar(stat = "identity", fill = "steelblue2") +
labs(x = "Variable") +
theme_minimal() +
theme(panel.grid.major.x = element_blank(),
panel.grid.minor.x = element_blank(),
axis.text.x = element_text(angle = 70, hjust=1, size = 11))
The final model used in the analysis is eXtreme Gradient Boosting (XGBoost) - a popular implementation of gradient boosting. Let’s see the tunable parameters of the model; we will try to tune them.
modelLookup("xgbTree")
## model parameter label forReg forClass
## 1 xgbTree nrounds # Boosting Iterations TRUE TRUE
## 2 xgbTree max_depth Max Tree Depth TRUE TRUE
## 3 xgbTree eta Shrinkage TRUE TRUE
## 4 xgbTree gamma Minimum Loss Reduction TRUE TRUE
## 5 xgbTree colsample_bytree Subsample Ratio of Columns TRUE TRUE
## 6 xgbTree min_child_weight Minimum Sum of Instance Weight TRUE TRUE
## 7 xgbTree subsample Subsample Percentage TRUE TRUE
## probModel
## 1 TRUE
## 2 TRUE
## 3 TRUE
## 4 TRUE
## 5 TRUE
## 6 TRUE
## 7 TRUE
In the first step we will try to tune nrounds - the number of boosting iterations - for the chosen value of the learning rate (eta), which we have set to 0.25. colsample_bytree was chosen with the formula sqrt(ncol(taxis_train)-1)/(ncol(taxis_train)-1), which is roughly 0.3 (see the check below). The minimum child weight was set to roughly 1% of all observations - that is 1,000. Subsample is set at 0.8, which is a popular starting value. The initial maximum depth is set to 6.
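The arithmetic behind that starting colsample_bytree (13 predictors after dropping the target) is simply:
#square root of the number of predictors, divided by the number of predictors
n_pred <- ncol(taxis_train) - 1
sqrt(n_pred) / n_pred  #roughly 0.28, rounded up to 0.3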
parameters_xgb <- expand.grid(nrounds = seq(10, 120, 10),
max_depth = c(6),
eta = c(0.25),
gamma = 0.5,
colsample_bytree = c(0.3),
min_child_weight = c(1000),
subsample = 0.8)
ctrl_cv5 <- trainControl(method = "cv", number = 5)
set.seed(123)
taxis_train_xgb <- train(fare_amount~.,
data = taxis_train,
method = "xgbTree",
trControl = ctrl_cv5,
tuneGrid = parameters_xgb)
print(taxis_train_xgb)
## eXtreme Gradient Boosting
##
## 73440 samples
## 13 predictor
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 58753, 58751, 58752, 58752, 58752
## Resampling results across tuning parameters:
##
## nrounds RMSE Rsquared MAE
## 10 5.348928 0.7924725 2.862168
## 20 4.893551 0.8072461 2.549705
## 30 4.771230 0.8151338 2.428840
## 40 4.664256 0.8221246 2.331468
## 50 4.604426 0.8264198 2.283382
## 60 4.570672 0.8289696 2.257769
## 70 4.546816 0.8307298 2.242102
## 80 4.520520 0.8325666 2.226776
## 90 4.507314 0.8335265 2.221283
## 100 4.492044 0.8346756 2.211534
## 110 4.482856 0.8353026 2.207957
## 120 4.469827 0.8362389 2.202159
##
## Tuning parameter 'max_depth' was held constant at a value of 6
## Tuning
##
## Tuning parameter 'min_child_weight' was held constant at a value of
## 1000
## Tuning parameter 'subsample' was held constant at a value of 0.8
## RMSE was used to select the optimal model using the smallest value.
## The final values used for the model were nrounds = 120, max_depth = 6, eta
## = 0.25, gamma = 0.5, colsample_bytree = 0.3, min_child_weight = 1000
## and subsample = 0.8.
getRegressionMetrics(real = taxis_train$fare_amount,
predicted = predict(taxis_train_xgb, taxis_train))
## MSE RMSE MAE MedAE MSLE R2
## 1 19.19167 4.38083 2.164836 1.326693 0.05635419 0.8427541
getRegressionMetrics(real = taxis_test$fare_amount,
predicted = predict(taxis_train_xgb, taxis_test))
## MSE RMSE MAE MedAE MSLE R2
## 1 20.65233 4.544484 2.226242 1.334519 0.05889731 0.8314976
The optimal nrounds was selected as 120. In the following steps we will tune the other parameters with nrounds fixed at this value. Let’s first tune the maximum depth of the tree and the minimum child weight.
parameters_xgb2 <- expand.grid(nrounds = 120,
max_depth = c(5, 6, 7, 8, 9, 10, 11),
eta = c(0.25),
gamma = 0.5,
colsample_bytree = c(0.3),
min_child_weight = seq(400, 1600, 200),
subsample = 0.8)
set.seed(123)
taxis_train_xgb2 <- train(fare_amount~.,
data = taxis_train,
method = "xgbTree",
trControl = ctrl_cv5,
tuneGrid = parameters_xgb2)
saveRDS(taxis_train_xgb2, "taxis_train_xgb2.rds")
taxis_train_xgb2 <- readRDS("taxis_train_xgb2.rds")
print(taxis_train_xgb2)
## eXtreme Gradient Boosting
##
## 73440 samples
## 13 predictor
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 58753, 58751, 58752, 58752, 58752
## Resampling results across tuning parameters:
##
## max_depth min_child_weight RMSE Rsquared MAE
## 5 400 4.378321 0.8428765 2.183055
## 5 600 4.393396 0.8418134 2.176830
## 5 800 4.455484 0.8373294 2.207530
## 5 1000 4.491529 0.8346829 2.212707
## 5 1200 4.533363 0.8315829 2.227233
## 5 1400 4.601312 0.8265283 2.250038
## 5 1600 4.671051 0.8211509 2.298953
## 6 400 4.380332 0.8426530 2.188742
## 6 600 4.412945 0.8403269 2.193471
## 6 800 4.442246 0.8382345 2.190749
## 6 1000 4.503697 0.8338061 2.221895
## 6 1200 4.501791 0.8339261 2.203712
## 6 1400 4.570766 0.8288218 2.231932
## 6 1600 4.648307 0.8229627 2.266465
## 7 400 4.379015 0.8428620 2.187642
## 7 600 4.402109 0.8412096 2.189694
## 7 800 4.450452 0.8376633 2.209561
## 7 1000 4.497801 0.8342068 2.214480
## 7 1200 4.551271 0.8302229 2.248304
## 7 1400 4.593937 0.8270942 2.248251
## 7 1600 4.630814 0.8242803 2.264869
## 8 400 4.357312 0.8443057 2.173841
## 8 600 4.370500 0.8434318 2.166970
## 8 800 4.411995 0.8404606 2.163560
## 8 1000 4.467535 0.8364812 2.193208
## 8 1200 4.539795 0.8311040 2.244381
## 8 1400 4.576130 0.8283393 2.228454
## 8 1600 4.669216 0.8212631 2.280820
## 9 400 4.362770 0.8439404 2.175983
## 9 600 4.380300 0.8427678 2.168992
## 9 800 4.431870 0.8390565 2.188729
## 9 1000 4.485379 0.8351770 2.205851
## 9 1200 4.500345 0.8340077 2.197421
## 9 1400 4.588352 0.8274704 2.239901
## 9 1600 4.628865 0.8243975 2.260001
## 10 400 4.360591 0.8441218 2.185620
## 10 600 4.400049 0.8413168 2.186332
## 10 800 4.439360 0.8384453 2.185110
## 10 1000 4.485939 0.8350798 2.209669
## 10 1200 4.558180 0.8296727 2.253659
## 10 1400 4.575747 0.8283874 2.228535
## 10 1600 4.655063 0.8224256 2.291033
## 11 400 4.334670 0.8458802 2.165196
## 11 600 4.371444 0.8433607 2.163947
## 11 800 4.467310 0.8364257 2.210921
## 11 1000 4.467500 0.8364041 2.194344
## 11 1200 4.500297 0.8340135 2.206113
## 11 1400 4.595242 0.8269794 2.239285
## 11 1600 4.637978 0.8237132 2.253726
##
## Tuning parameter 'nrounds' was held constant at a value of 120
## Tuning
## parameter 'colsample_bytree' was held constant at a value of 0.3
##
## Tuning parameter 'subsample' was held constant at a value of 0.8
## RMSE was used to select the optimal model using the smallest value.
## The final values used for the model were nrounds = 120, max_depth = 11, eta
## = 0.25, gamma = 0.5, colsample_bytree = 0.3, min_child_weight = 400
## and subsample = 0.8.
So far we have arrived at the following parameters: nrounds = 120, max_depth = 11, colsample_bytree = 0.3, min_child_weight = 400.
In the next step, we will try to choose the optimal subsample.
parameters_xgb3 <- expand.grid(nrounds = 120,
max_depth = c(11),
eta = c(0.25),
gamma = 0.5,
colsample_bytree = c(0.3),
min_child_weight = c(400),
subsample = c(0.6, 0.7, 0.75, 0.8, 0.85, 0.9))
set.seed(123)
taxis_train_xgb3 <- train(fare_amount~.,
data = taxis_train,
method = "xgbTree",
trControl = ctrl_cv5,
tuneGrid = parameters_xgb3)
saveRDS(taxis_train_xgb3, "taxis_train_xgb3.rds")
taxis_train_xgb3 <- readRDS("taxis_train_xgb3.rds")
print(taxis_train_xgb3)
## eXtreme Gradient Boosting
##
## 73181 samples
## 13 predictor
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 58546, 58544, 58544, 58545, 58545
## Resampling results across tuning parameters:
##
## subsample RMSE Rsquared MAE
## 0.60 4.154130 0.8119946 1.987128
## 0.70 4.127502 0.8144139 1.972878
## 0.75 4.117653 0.8153454 1.960745
## 0.80 4.124486 0.8146700 1.966768
## 0.85 4.110800 0.8159767 1.969309
## 0.90 4.098681 0.8170226 1.962782
##
## Tuning parameter 'nrounds' was held constant at a value of 120
## Tuning
## held constant at a value of 0.3
## Tuning parameter 'min_child_weight' was
## held constant at a value of 400
## RMSE was used to select the optimal model using the smallest value.
## The final values used for the model were nrounds = 120, max_depth = 8, eta
## = 0.25, gamma = 0.5, colsample_bytree = 0.3, min_child_weight = 400
## and subsample = 0.9.
Final parameters: nrounds = 120, max_depth = 8, eta = 0.25, gamma = 0.5, colsample_bytree = 0.3, min_child_weight = 400, subsample = 0.9.
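For reference, the same final configuration could also be fit directly with the xgboost package (a sketch only; caret’s formula interface creates dummy variables for the factors, which we do explicitly with model.matrix() here):
#sketch: fit the final configuration directly with xgboost
X_train <- model.matrix(fare_amount ~ . - 1, data = taxis_train)
xgb_final <- xgboost(data = X_train, label = taxis_train$fare_amount,
                     nrounds = 120, max_depth = 8, eta = 0.25, gamma = 0.5,
                     colsample_bytree = 0.3, min_child_weight = 400,
                     subsample = 0.9, objective = "reg:squarederror", verbose = 0)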
model_results_xgboost_train <- getRegressionMetrics(real = taxis_train$fare_amount,
predicted = predict(taxis_train_xgb3, taxis_train))
model_results_xgboost_test <- getRegressionMetrics(real = taxis_test$fare_amount,
predicted = predict(taxis_train_xgb3, taxis_test))
model_results_xgboost_train
## MSE RMSE MAE MedAE MSLE R2
## 1 18.95723 4.35399 2.191778 1.374749 0.05883315 0.8446751
model_results_xgboost_test
## MSE RMSE MAE MedAE MSLE R2
## 1 19.14826 4.375873 2.212626 1.379736 0.05823421 0.8437693
RMSE for the XGBoost model on the test data is equal to 4.38 and R^2 to 0.84, which is comparable to the results of the bagged tree and the random forest.
On the plot below we can see the complexity of the model. At a leaf depth of 9 both the number of leaves and the cover rise significantly. For our final model, the max depth has been set to 8.
#xgb.plot.tree(model = taxis_train_xgb3$finalModel, trees = 1)
xgb.plot.deepness(model = taxis_train_xgb3$finalModel, col = "steelblue")
Finally, let’s compare the performance of the analyzed models. RMSE fluctuated around 4.4$, which is only moderately satisfactory given that the mean fare amount was about 12.9$. Based on RMSE, the bagged tree model performed best, although the random forest performed very similarly; the random forest, in turn, had a lower MAE, MedAE and MSLE. All the tree-based models, however, achieved better results (in terms of RMSE) than the simple linear regression model.
results <- cbind(
data.frame(
model = c("Linear regression",
"Random forest",
"Regression tree",
"Bagged tree",
"XGBoost")),
rbind(
model_linear_results_test,
model_forest_results_test,
model_tree_results_test,
model_bagg_results_test,
model_results_xgboost_test)
)
results %>%
arrange(RMSE)
## model MSE RMSE MAE MedAE MSLE R2
## 1 Bagged tree 18.80180 4.336104 2.369062 1.550254 0.06411911 0.8465961
## 2 Random forest 18.88592 4.345793 2.090309 1.264889 0.05627187 0.8459098
## 3 XGBoost 19.14826 4.375873 2.212626 1.379736 0.05823421 0.8437693
## 4 Regression tree 19.75310 4.444446 2.422870 1.576470 0.06753319 0.8388345
## 5 Linear regression 21.11537 4.595146 2.321552 1.535537 0.06446535 0.8277197
We have shown that machine learning models can be used to predict a taxi fare with a reasonable level of accuracy. Moreover, for the analysed problem, ensemble and gradient boosting methods performed somewhat better than a single regression tree and the linear model. We think that the accuracy of prediction could be further improved by putting more emphasis on feature engineering and on filtering out inaccurate data. It may be an interesting idea to find more detailed information about the specific surcharges in the analyzed years and try to link given coordinates to additional fees. Additionally, more machine learning methods could be compared.
### Bibliography
https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/overview
https://www.kaggle.com/tapendrakumar09/xgboost-lgbm-dnn
https://www.scirp.org/(S(lz5mqp453edsnp55rrgjct55.))/reference/referencespapers.aspx?referenceid=2763470
https://xgboost.readthedocs.io/en/stable/