We are going to use tidymodels, RStudio's collection of packages for modeling and machine learning in R.
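library(tidyverse)   # readr, dplyr, tidyr, ggplot2, forcats
library(tidymodels)  # rsample, parsnip, workflows, tune, yardstick
library(skimr)       # skim() summary table
aus <- read_csv("weatherAUS.csv", na = "NA", col_types = cols(
Evaporation = col_double(),   # read these two explicitly as numeric
Sunshine = col_double()
))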
#### Summary of the data
glimpse(aus)
## Rows: 145,460
## Columns: 23
## $ Date <date> 2008-12-01, 2008-12-02, 2008-12-03, 2008-12-04, 2008...
## $ Location <chr> "Albury", "Albury", "Albury", "Albury", "Albury", "Al...
## $ MinTemp <dbl> 13.4, 7.4, 12.9, 9.2, 17.5, 14.6, 14.3, 7.7, 9.7, 13....
## $ MaxTemp <dbl> 22.9, 25.1, 25.7, 28.0, 32.3, 29.7, 25.0, 26.7, 31.9,...
## $ Rainfall <dbl> 0.6, 0.0, 0.0, 0.0, 1.0, 0.2, 0.0, 0.0, 0.0, 1.4, 0.0...
## $ Evaporation <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N...
## $ Sunshine <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N...
## $ WindGustDir <chr> "W", "WNW", "WSW", "NE", "W", "WNW", "W", "W", "NNW",...
## $ WindGustSpeed <dbl> 44, 44, 46, 24, 41, 56, 50, 35, 80, 28, 30, 31, 61, 4...
## $ WindDir9am <chr> "W", "NNW", "W", "SE", "ENE", "W", "SW", "SSE", "SE",...
## $ WindDir3pm <chr> "WNW", "WSW", "WSW", "E", "NW", "W", "W", "W", "NW", ...
## $ WindSpeed9am <dbl> 20, 4, 19, 11, 7, 19, 20, 6, 7, 15, 17, 15, 28, 24, 4...
## $ WindSpeed3pm <dbl> 24, 22, 26, 9, 20, 24, 24, 17, 28, 11, 6, 13, 28, 20,...
## $ Humidity9am <dbl> 71, 44, 38, 45, 82, 55, 49, 48, 42, 58, 48, 89, 76, 6...
## $ Humidity3pm <dbl> 22, 25, 30, 16, 33, 23, 19, 19, 9, 27, 22, 91, 93, 43...
## $ Pressure9am <dbl> 1007.7, 1010.6, 1007.6, 1017.6, 1010.8, 1009.2, 1009....
## $ Pressure3pm <dbl> 1007.1, 1007.8, 1008.7, 1012.8, 1006.0, 1005.4, 1008....
## $ Cloud9am <dbl> 8, NA, NA, NA, 7, NA, 1, NA, NA, NA, NA, 8, 8, NA, NA...
## $ Cloud3pm <dbl> NA, NA, 2, NA, 8, NA, NA, NA, NA, NA, NA, 8, 8, 7, NA...
## $ Temp9am <dbl> 16.9, 17.2, 21.0, 18.1, 17.8, 20.6, 18.1, 16.3, 18.3,...
## $ Temp3pm <dbl> 21.8, 24.3, 23.2, 26.5, 29.7, 28.9, 24.6, 25.5, 30.2,...
## $ RainToday <chr> "No", "No", "No", "No", "No", "No", "No", "No", "No",...
## $ RainTomorrow <chr> "No", "No", "No", "No", "No", "No", "No", "No", "Yes"...
summary(aus)
## Date Location MinTemp MaxTemp
## Min. :2007-11-01 Length:145460 Min. :-8.50 Min. :-4.80
## 1st Qu.:2011-01-11 Class :character 1st Qu.: 7.60 1st Qu.:17.90
## Median :2013-06-02 Mode :character Median :12.00 Median :22.60
## Mean :2013-04-04 Mean :12.19 Mean :23.22
## 3rd Qu.:2015-06-14 3rd Qu.:16.90 3rd Qu.:28.20
## Max. :2017-06-25 Max. :33.90 Max. :48.10
## NA's :1485 NA's :1261
## Rainfall Evaporation Sunshine WindGustDir
## Min. : 0.000 Min. : 0.00 Min. : 0.00 Length:145460
## 1st Qu.: 0.000 1st Qu.: 2.60 1st Qu.: 4.80 Class :character
## Median : 0.000 Median : 4.80 Median : 8.40 Mode :character
## Mean : 2.361 Mean : 5.47 Mean : 7.61
## 3rd Qu.: 0.800 3rd Qu.: 7.40 3rd Qu.:10.60
## Max. :371.000 Max. :145.00 Max. :14.50
## NA's :3261 NA's :62790 NA's :69835
## WindGustSpeed WindDir9am WindDir3pm WindSpeed9am
## Min. : 6.00 Length:145460 Length:145460 Min. : 0.00
## 1st Qu.: 31.00 Class :character Class :character 1st Qu.: 7.00
## Median : 39.00 Mode :character Mode :character Median : 13.00
## Mean : 40.03 Mean : 14.04
## 3rd Qu.: 48.00 3rd Qu.: 19.00
## Max. :135.00 Max. :130.00
## NA's :10263 NA's :1767
## WindSpeed3pm Humidity9am Humidity3pm Pressure9am
## Min. : 0.00 Min. : 0.00 Min. : 0.00 Min. : 980.5
## 1st Qu.:13.00 1st Qu.: 57.00 1st Qu.: 37.00 1st Qu.:1012.9
## Median :19.00 Median : 70.00 Median : 52.00 Median :1017.6
## Mean :18.66 Mean : 68.88 Mean : 51.54 Mean :1017.6
## 3rd Qu.:24.00 3rd Qu.: 83.00 3rd Qu.: 66.00 3rd Qu.:1022.4
## Max. :87.00 Max. :100.00 Max. :100.00 Max. :1041.0
## NA's :3062 NA's :2654 NA's :4507 NA's :15065
## Pressure3pm Cloud9am Cloud3pm Temp9am
## Min. : 977.1 Min. :0.00 Min. :0.00 Min. :-7.20
## 1st Qu.:1010.4 1st Qu.:1.00 1st Qu.:2.00 1st Qu.:12.30
## Median :1015.2 Median :5.00 Median :5.00 Median :16.70
## Mean :1015.3 Mean :4.45 Mean :4.51 Mean :16.99
## 3rd Qu.:1020.0 3rd Qu.:7.00 3rd Qu.:7.00 3rd Qu.:21.60
## Max. :1039.6 Max. :9.00 Max. :9.00 Max. :40.20
## NA's :15028 NA's :55888 NA's :59358 NA's :1767
## Temp3pm RainToday RainTomorrow
## Min. :-5.40 Length:145460 Length:145460
## 1st Qu.:16.60 Class :character Class :character
## Median :21.10 Mode :character Mode :character
## Mean :21.68
## 3rd Qu.:26.40
## Max. :46.70
## NA's :3609
skim(aus)
| Data summary | |
|---|---|
| Name | aus |
| Number of rows | 145460 |
| Number of columns | 23 |
| _______________________ | |
| Column type frequency: | |
| character | 6 |
| Date | 1 |
| numeric | 16 |
| ________________________ | |
| Group variables | None |
Variable type: character
| skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
|---|---|---|---|---|---|---|---|
| Location | 0 | 1.00 | 4 | 16 | 0 | 49 | 0 |
| WindGustDir | 10326 | 0.93 | 1 | 3 | 0 | 16 | 0 |
| WindDir9am | 10566 | 0.93 | 1 | 3 | 0 | 16 | 0 |
| WindDir3pm | 4228 | 0.97 | 1 | 3 | 0 | 16 | 0 |
| RainToday | 3261 | 0.98 | 2 | 3 | 0 | 2 | 0 |
| RainTomorrow | 3267 | 0.98 | 2 | 3 | 0 | 2 | 0 |
Variable type: Date
| skim_variable | n_missing | complete_rate | min | max | median | n_unique |
|---|---|---|---|---|---|---|
| Date | 0 | 1 | 2007-11-01 | 2017-06-25 | 2013-06-02 | 3436 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| MinTemp | 1485 | 0.99 | 12.19 | 6.40 | -8.5 | 7.6 | 12.0 | 16.9 | 33.9 | ▁▅▇▅▁ |
| MaxTemp | 1261 | 0.99 | 23.22 | 7.12 | -4.8 | 17.9 | 22.6 | 28.2 | 48.1 | ▁▂▇▅▁ |
| Rainfall | 3261 | 0.98 | 2.36 | 8.48 | 0.0 | 0.0 | 0.0 | 0.8 | 371.0 | ▇▁▁▁▁ |
| Evaporation | 62790 | 0.57 | 5.47 | 4.19 | 0.0 | 2.6 | 4.8 | 7.4 | 145.0 | ▇▁▁▁▁ |
| Sunshine | 69835 | 0.52 | 7.61 | 3.79 | 0.0 | 4.8 | 8.4 | 10.6 | 14.5 | ▃▃▅▇▃ |
| WindGustSpeed | 10263 | 0.93 | 40.04 | 13.61 | 6.0 | 31.0 | 39.0 | 48.0 | 135.0 | ▅▇▁▁▁ |
| WindSpeed9am | 1767 | 0.99 | 14.04 | 8.92 | 0.0 | 7.0 | 13.0 | 19.0 | 130.0 | ▇▁▁▁▁ |
| WindSpeed3pm | 3062 | 0.98 | 18.66 | 8.81 | 0.0 | 13.0 | 19.0 | 24.0 | 87.0 | ▇▇▁▁▁ |
| Humidity9am | 2654 | 0.98 | 68.88 | 19.03 | 0.0 | 57.0 | 70.0 | 83.0 | 100.0 | ▁▁▅▇▆ |
| Humidity3pm | 4507 | 0.97 | 51.54 | 20.80 | 0.0 | 37.0 | 52.0 | 66.0 | 100.0 | ▂▅▇▆▂ |
| Pressure9am | 15065 | 0.90 | 1017.65 | 7.11 | 980.5 | 1012.9 | 1017.6 | 1022.4 | 1041.0 | ▁▁▇▇▁ |
| Pressure3pm | 15028 | 0.90 | 1015.26 | 7.04 | 977.1 | 1010.4 | 1015.2 | 1020.0 | 1039.6 | ▁▁▇▇▁ |
| Cloud9am | 55888 | 0.62 | 4.45 | 2.89 | 0.0 | 1.0 | 5.0 | 7.0 | 9.0 | ▇▃▃▇▅ |
| Cloud3pm | 59358 | 0.59 | 4.51 | 2.72 | 0.0 | 2.0 | 5.0 | 7.0 | 9.0 | ▆▅▃▇▃ |
| Temp9am | 1767 | 0.99 | 16.99 | 6.49 | -7.2 | 12.3 | 16.7 | 21.6 | 40.2 | ▁▃▇▃▁ |
| Temp3pm | 3609 | 0.98 | 21.68 | 6.94 | -5.4 | 16.6 | 21.1 | 26.4 | 46.7 | ▁▃▇▃▁ |
The data contain daily weather observations from 49 locations across Australia, covering 2007-11-01 to 2017-06-25.
Our target is RainTomorrow: we want to predict whether it will rain tomorrow from today's observations.
### Sunshine looks important, but it has about 70,000 missing values
aus %>%
select(Sunshine, RainTomorrow) %>%
drop_na() %>%
ggplot(aes(Sunshine, fill = RainTomorrow)) +
geom_density(alpha = 0.3)
### Evaporation does not seem important, as the distributions for Yes and No overlap heavily
aus %>%
select(Evaporation, RainTomorrow) %>%
drop_na() %>%
ggplot(aes(Evaporation, fill = RainTomorrow)) +
geom_density(alpha = 0.3)
Both Sunshine and Evaporation have too many missing values to be used in the model.
### Missing RainTomorrow by Location
aus %>%
select(Location, RainTomorrow) %>%
count(Location, RainTomorrow) %>%
arrange(n, Location) %>%
filter(is.na(RainTomorrow)) %>%
mutate(total = sum(n)) %>%
ggplot(aes(n , fct_reorder(as.factor(Location), n))) +
geom_col(position = "dodge")
### Rain percentage by location
aus %>%
select(Location, RainTomorrow) %>%
count(Location, RainTomorrow) %>%
drop_na() %>%
group_by(Location) %>%
mutate(percent_Rain = n / sum(n)) %>%
ggplot(aes(percent_Rain , Location, fill = RainTomorrow)) +
geom_col() +
ylab("Location") +
xlab("Proportion of days")
Some locations have noticeably more rainy days than others.
aus %>%
mutate(mon = lubridate::month(Date, label = TRUE)) %>%
count(Location, mon, RainTomorrow) %>%
drop_na() %>%
group_by(Location, mon) %>%
mutate(percent_Rain = n / sum(n)) %>%
ggplot(aes(percent_Rain , mon, fill = RainTomorrow)) +
geom_col() +
facet_wrap(~Location)
Within each location, the share of rainy days changes markedly from month to month, so Month looks like an important predictor.
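The Month variable used in this plot, and later as a predictor, is just the calendar month of Date. `lubridate::month(label = TRUE)` takes its labels from the system locale, so a non-English setup would produce different level names. The base-R sketch below uses a few made-up dates rather than the real data. It shows one locale-independent way to build the same kind of unordered month factor.

```r
# A few example dates (made up for illustration, not read from weatherAUS.csv)
dates <- as.Date(c("2008-12-01", "2009-01-15", "2010-07-04", "2011-12-31"))

# format(..., "%m") returns "01".."12" in any locale; month.abb is always English
month_num <- as.integer(format(dates, "%m"))
Month <- factor(month.abb[month_num], levels = month.abb)  # unordered, levels Jan..Dec

print(Month)
print(table(Month))
```

Because the levels are fixed to Jan..Dec, months that never occur still appear in `table()` with a count of zero.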
aus %>%
select(MinTemp, RainTomorrow) %>%
drop_na() %>%
ggplot(aes(MinTemp, fill = RainTomorrow)) +
geom_density(alpha = 0.3)
aus %>%
select(MaxTemp, RainTomorrow) %>%
drop_na() %>%
ggplot(aes(MaxTemp, fill = RainTomorrow)) +
geom_density(alpha = 0.3)
MinTemp and MaxTemp do not look very important on their own, but they could be useful in combination with Location.
aus %>%
select(Rainfall, RainTomorrow) %>%
drop_na() %>%
ggplot(aes(Rainfall, fill = RainTomorrow)) +
geom_density( alpha = 0.3) +
xlim(0, 5)
aus %>%
count(RainToday, RainTomorrow) %>%
drop_na()
## # A tibble: 4 x 3
## RainToday RainTomorrow n
## <chr> <chr> <int>
## 1 No No 92728
## 2 No Yes 16604
## 3 Yes No 16858
## 4 Yes Yes 14597
chisq.test(table(aus$RainToday, aus$RainTomorrow))
##
## Pearson's Chi-squared test with Yates' continuity correction
##
## data: table(aus$RainToday, aus$RainTomorrow)
## X-squared = 13799, df = 1, p-value < 2.2e-16
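Rainfall seems important: the Yes distribution has a longer right tail. The related variable RainToday is also strongly associated with RainTomorrow, since the chi-squared test of independence is highly significant (p < 2.2e-16).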
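The counts printed above are enough to reproduce the chi-squared test and to gauge the size of the effect without loading the data. The sketch below rebuilds the 2x2 table from those counts, with rows for RainToday and columns for RainTomorrow. It then compares the chance of rain tomorrow after a dry day and after a rainy day.

```r
# Counts copied from the RainToday x RainTomorrow tibble above
tab <- matrix(c(92728, 16604,
                16858, 14597),
              nrow = 2, byrow = TRUE,
              dimnames = list(RainToday = c("No", "Yes"),
                              RainTomorrow = c("No", "Yes")))

# Pearson's chi-squared test; Yates' continuity correction is the default for 2x2 tables
chi <- chisq.test(tab)
print(chi)

# Share of rainy tomorrows overall, and conditional on whether it rained today
overall_rain <- sum(tab[, "Yes"]) / sum(tab)
p_rain_given <- tab[, "Yes"] / rowSums(tab)
print(round(overall_rain, 3))
print(round(p_rain_given, 3))
```

Roughly 15% of dry days are followed by rain, compared with about 46% of rainy days. That difference is what the tiny p-value reflects.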
aus %>%
select(WindGustSpeed, RainTomorrow) %>%
drop_na() %>%
ggplot(aes(WindGustSpeed, fill = RainTomorrow)) +
geom_density(alpha = 0.3)
aus %>%
select(WindGustDir, RainTomorrow) %>%
count(WindGustDir, RainTomorrow) %>%
drop_na() %>%
group_by(WindGustDir) %>%
mutate(percent_Rain = n / sum(n)) %>%
ggplot(aes(percent_Rain , WindGustDir, fill = RainTomorrow)) +
geom_col()
WindGustSpeed and WindGustDir both appear to be related to RainTomorrow.
aus %>%
select(WindDir9am, RainTomorrow) %>%
count(WindDir9am, RainTomorrow) %>%
drop_na() %>%
group_by(WindDir9am) %>%
mutate(percent_Rain = n / sum(n)) %>%
ggplot(aes(percent_Rain , WindDir9am, fill = RainTomorrow)) +
geom_col()
aus %>%
select(WindDir3pm, RainTomorrow) %>%
count(WindDir3pm, RainTomorrow) %>%
drop_na() %>%
group_by(WindDir3pm) %>%
mutate(percent_Rain = n / sum(n)) %>%
ggplot(aes(percent_Rain , WindDir3pm, fill = RainTomorrow)) +
geom_col()
Wind direction at 9am and 3pm also appears to matter.
### Humidity seems to have a big effect
aus %>%
select(Humidity9am, RainTomorrow) %>%
drop_na() %>%
ggplot(aes(Humidity9am, fill = RainTomorrow)) +
geom_density(alpha = 0.3)
aus %>%
select(Humidity3pm, RainTomorrow) %>%
drop_na() %>%
ggplot(aes(Humidity3pm, fill = RainTomorrow)) +
geom_density(alpha = 0.3)
### Pressure seems to have a moderate effect
aus %>%
select(Pressure9am, RainTomorrow) %>%
drop_na() %>%
ggplot(aes(Pressure9am, fill = RainTomorrow)) +
geom_density(alpha = 0.3)
aus %>%
select(Pressure3pm, RainTomorrow) %>%
drop_na() %>%
ggplot(aes(Pressure3pm, fill = RainTomorrow)) +
geom_density(alpha = 0.3)
### Temperature seems important
aus %>%
select(Temp9am, RainTomorrow) %>%
drop_na() %>%
ggplot(aes(Temp9am, fill = RainTomorrow)) +
geom_density(alpha = 0.3)
aus %>%
select(Temp3pm, RainTomorrow) %>%
drop_na() %>%
ggplot(aes(Temp3pm, fill = RainTomorrow)) +
geom_density(alpha = 0.3)
Next we keep only the variables that showed some relationship with RainTomorrow and store them in a reduced dataset, aus_df. This dataset is used for both the logistic regression and the random forest. Rows with a missing value in any of these columns are dropped. Sunshine and Evaporation are left out because each is missing in close to half of the rows; keeping them would discard almost half of the data.
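The missing-value counts in the skim() table above make this trade-off concrete. The sketch below uses only numbers printed in this document. It shows what fraction of the 145,460 rows each excluded variable alone would remove, and what fraction survives in aus_df (112,925 rows).

```r
n_total <- 145460

# Missing counts for the two excluded variables, from the skim() output
n_missing <- c(Sunshine = 69835, Evaporation = 62790)

# Fraction of rows lost if each variable alone had to be complete
frac_lost <- round(n_missing / n_total, 3)
print(frac_lost)

# Fraction of rows kept in aus_df after drop_na() on the selected columns
frac_kept <- round(112925 / n_total, 3)
print(frac_kept)
```

Requiring Sunshine alone would already drop 48% of the rows. With it, at most 52% of the data could remain, compared with the 77.6% kept now.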
aus_df <- aus %>%
select(Location,
RainTomorrow,
WindGustDir,
WindDir9am,
WindDir3pm,
RainToday,
WindGustSpeed,
MinTemp,
MaxTemp,
Rainfall,
Humidity9am,
Humidity3pm,
Pressure9am,
Pressure3pm,
Temp3pm,
Date
) %>%
### Extract the month from Date, then drop Date itself
mutate(Month = factor(lubridate::month(Date, label = TRUE), ordered = FALSE)) %>%
select(-Date) %>%
drop_na() %>%
mutate_if(is.character, as.factor) %>%
mutate(RainTomorrow = relevel(RainTomorrow, ref = "Yes"))
glimpse(aus_df)
## Rows: 112,925
## Columns: 16
## $ Location <fct> Albury, Albury, Albury, Albury, Albury, Albury, Albur...
## $ RainTomorrow <fct> No, No, No, No, No, No, No, No, Yes, No, Yes, Yes, Ye...
## $ WindGustDir <fct> W, WNW, WSW, NE, W, WNW, W, W, NNW, W, N, NNE, W, SW,...
## $ WindDir9am <fct> W, NNW, W, SE, ENE, W, SW, SSE, SE, S, SSE, NE, NNW, ...
## $ WindDir3pm <fct> WNW, WSW, WSW, E, NW, W, W, W, NW, SSE, ESE, ENE, NNW...
## $ RainToday <fct> No, No, No, No, No, No, No, No, No, Yes, No, Yes, Yes...
## $ WindGustSpeed <dbl> 44, 44, 46, 24, 41, 56, 50, 35, 80, 28, 30, 31, 61, 4...
## $ MinTemp <dbl> 13.4, 7.4, 12.9, 9.2, 17.5, 14.6, 14.3, 7.7, 9.7, 13....
## $ MaxTemp <dbl> 22.9, 25.1, 25.7, 28.0, 32.3, 29.7, 25.0, 26.7, 31.9,...
## $ Rainfall <dbl> 0.6, 0.0, 0.0, 0.0, 1.0, 0.2, 0.0, 0.0, 0.0, 1.4, 0.0...
## $ Humidity9am <dbl> 71, 44, 38, 45, 82, 55, 49, 48, 42, 58, 48, 89, 76, 6...
## $ Humidity3pm <dbl> 22, 25, 30, 16, 33, 23, 19, 19, 9, 27, 22, 91, 93, 43...
## $ Pressure9am <dbl> 1007.7, 1010.6, 1007.6, 1017.6, 1010.8, 1009.2, 1009....
## $ Pressure3pm <dbl> 1007.1, 1007.8, 1008.7, 1012.8, 1006.0, 1005.4, 1008....
## $ Temp3pm <dbl> 21.8, 24.3, 23.2, 26.5, 29.7, 28.9, 24.6, 25.5, 30.2,...
## $ Month <fct> Dec, Dec, Dec, Dec, Dec, Dec, Dec, Dec, Dec, Dec, Dec...
levels(aus_df$RainTomorrow)
## [1] "Yes" "No"
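"Yes" is now the first level of RainTomorrow. yardstick treats the first level as the event by default, so metrics such as sensitivity and ROC AUC refer to predicting rain.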
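A toy example with made-up labels, not the weather data, shows why the level order matters. The helper `sens_first_level()` below is hypothetical. It mimics yardstick's default convention of treating the first factor level as the event, so the same predictions give very different "sensitivity" depending on which level comes first.

```r
truth    <- factor(c("Yes", "Yes", "Yes", "No", "No"))
estimate <- factor(c("Yes", "No",  "No",  "No", "No"), levels = levels(truth))

# Hypothetical helper: sensitivity with the FIRST factor level as the event
sens_first_level <- function(truth, estimate) {
  event <- levels(truth)[1]
  sum(truth == event & estimate == event) / sum(truth == event)
}

# Default alphabetical levels are "No", "Yes", so the "event" is No
sens_default <- sens_first_level(truth, estimate)

# After relevel(), "Yes" comes first, so the event is rain
truth_yes    <- relevel(truth, ref = "Yes")
estimate_yes <- relevel(estimate, ref = "Yes")
sens_yes     <- sens_first_level(truth_yes, estimate_yes)

print(c(sens_default = sens_default, sens_yes = sens_yes))
```

Here the predictions catch every dry day but only one of three rainy days. With "No" as the event that looks perfect; with "Yes" as the event it is a sensitivity of 1/3.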
We fit the models on the training data and evaluate them on the test data.
### We will use Logistic Regression and Random Forest for Modeling ###
set.seed(123)
aus_split <- initial_split(aus_df, strata = RainTomorrow)
aus_train <- training(aus_split)
aus_test <- testing(aus_split)
aus_split # Shows Training/Testing/Total observations
## <Analysis/Assess/Total>
## <84695/28230/112925>
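Next we create 10-fold cross-validation resamples from the training set, stratified by RainTomorrow so that each fold keeps a similar share of rainy days. We use these resamples to estimate the logistic regression's performance and to tune the random forest's hyperparameters.
aus_cv <- vfold_cv(aus_train, v = 10, strata = RainTomorrow)
aus_cv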
## # 10-fold cross-validation using stratification
## # A tibble: 10 x 2
## splits id
## <list> <chr>
## 1 <split [76.2K/8.5K]> Fold01
## 2 <split [76.2K/8.5K]> Fold02
## 3 <split [76.2K/8.5K]> Fold03
## 4 <split [76.2K/8.5K]> Fold04
## 5 <split [76.2K/8.5K]> Fold05
## 6 <split [76.2K/8.5K]> Fold06
## 7 <split [76.2K/8.5K]> Fold07
## 8 <split [76.2K/8.5K]> Fold08
## 9 <split [76.2K/8.5K]> Fold09
## 10 <split [76.2K/8.5K]> Fold10
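Stratification means that vfold_cv() assigns rows to folds separately within each outcome class, so every fold keeps about the same share of rainy days as the training set. The base-R sketch below illustrates that idea on simulated labels. It is not rsample's implementation, which for example also bins numeric strata.

```r
set.seed(123)
# Simulated outcome with roughly 22% "Yes", similar to RainTomorrow (not the real data)
y <- factor(sample(c("Yes", "No"), 1000, replace = TRUE, prob = c(0.22, 0.78)),
            levels = c("Yes", "No"))

v <- 10
fold <- integer(length(y))
for (lvl in levels(y)) {
  idx <- which(y == lvl)
  # Deal this class's rows out to folds 1..v as evenly as possible, in random order
  fold[idx] <- sample(rep_len(seq_len(v), length(idx)))
}

# Share of "Yes" in each fold's held-out set: nearly identical across folds
yes_share <- tapply(y == "Yes", fold, mean)
print(round(yes_share, 3))
print(table(fold))
```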
In this step we specify each model and the engine used to fit it. We also fit a logistic regression as a baseline to compare against the random forest.
logistic_reg() %>%
set_engine("glm")
## Logistic Regression Model Specification (classification)
##
## Computational engine: glm
### Random Forest
rf_spec <- rand_forest(
mtry = tune(), ### We tune this hyperparameter via Cross Validation
trees = 500, ### Grow 500 trees; their predictions are combined into one
min_n = tune() ### We tune this hyperparameter via Cross Validation
) %>%
set_mode("classification") %>%
set_engine("ranger")
rf_spec
## Random Forest Model Specification (classification)
##
## Main Arguments:
## mtry = tune()
## trees = 500
## min_n = tune()
##
## Computational engine: ranger
A workflow bundles preprocessing and a model into one object. We start with a formula preprocessor, RainTomorrow ~ . (all other columns are predictors), and add a model later. Each piece (formula or recipe, model specification) can be swapped independently, like Lego blocks. That makes it easy to change one part of the pipeline without touching the rest.
aus_wf <- workflow() %>%
add_formula(RainTomorrow ~ .)
aus_wf
## == Workflow ===========================================
## Preprocessor: Formula
## Model: None
##
## -- Preprocessor ---------------------------------------
## RainTomorrow ~ .
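For the logistic regression, factor predictors in `RainTomorrow ~ .` are expanded into one 0/1 indicator column per non-baseline level. This is why the coefficient table further down has terms such as LocationWollongong and MonthApr. The toy illustration below uses base R's `model.matrix()` on a made-up data frame.

```r
toy <- data.frame(
  RainTomorrow = factor(c("Yes", "No", "No", "Yes"), levels = c("Yes", "No")),
  Location     = factor(c("Albury", "Darwin", "Sydney", "Darwin")),
  Humidity3pm  = c(80, 30, 45, 70)
)

# "~ ." expands to every column except the outcome; factors get treatment (dummy) coding
mm <- model.matrix(RainTomorrow ~ ., data = toy)
print(mm)
print(colnames(mm))
```

The first level (Albury in this toy) has no column of its own; it is the baseline absorbed by the intercept. In the real data, each Location coefficient is likewise relative to whichever location is the first factor level.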
## # Resampling results
## # 10-fold cross-validation using stratification
## # A tibble: 10 x 5
## splits id .metrics .notes .predictions
## <list> <chr> <list> <list> <list>
## 1 <split [76.2K/8.5K~ Fold01 <tibble [2 x 3~ <tibble [0 x ~ <tibble [8,470 x 5~
## 2 <split [76.2K/8.5K~ Fold02 <tibble [2 x 3~ <tibble [0 x ~ <tibble [8,470 x 5~
## 3 <split [76.2K/8.5K~ Fold03 <tibble [2 x 3~ <tibble [0 x ~ <tibble [8,470 x 5~
## 4 <split [76.2K/8.5K~ Fold04 <tibble [2 x 3~ <tibble [0 x ~ <tibble [8,470 x 5~
## 5 <split [76.2K/8.5K~ Fold05 <tibble [2 x 3~ <tibble [0 x ~ <tibble [8,470 x 5~
## 6 <split [76.2K/8.5K~ Fold06 <tibble [2 x 3~ <tibble [0 x ~ <tibble [8,469 x 5~
## 7 <split [76.2K/8.5K~ Fold07 <tibble [2 x 3~ <tibble [0 x ~ <tibble [8,469 x 5~
## 8 <split [76.2K/8.5K~ Fold08 <tibble [2 x 3~ <tibble [0 x ~ <tibble [8,469 x 5~
## 9 <split [76.2K/8.5K~ Fold09 <tibble [2 x 3~ <tibble [0 x ~ <tibble [8,469 x 5~
## 10 <split [76.2K/8.5K~ Fold10 <tibble [2 x 3~ <tibble [0 x ~ <tibble [8,469 x 5~
## # A tibble: 2 x 5
## .metric .estimator mean n std_err
## <chr> <chr> <dbl> <int> <dbl>
## 1 accuracy binary 0.853 10 0.00118
## 2 roc_auc binary 0.874 10 0.00123
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 sens binary 0.520
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 spec binary 0.948
## # A tibble: 4 x 3
## Prediction Truth Freq
## <fct> <fct> <dbl>
## 1 Yes Yes 976.
## 2 Yes No 342.
## 3 No Yes 900.
## 4 No No 6251.
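With "Yes" as the event, the sensitivity and specificity above can be checked by hand from this averaged confusion matrix. The counts are per-fold averages and are rounded in the printout. The arithmetic also shows why sensitivity is low: only about 1,876 of the roughly 8,469 held-out days per fold are rainy. The model misses nearly half of them while correctly calling almost all dry days.

```r
# Average per-fold counts copied from the resampled confusion matrix above
tp <- 976   # predicted Yes, truly Yes
fp <- 342   # predicted Yes, truly No
fn <- 900   # predicted No,  truly Yes
tn <- 6251  # predicted No,  truly No

sensitivity <- tp / (tp + fn)                   # share of rainy days caught
specificity <- tn / (tn + fp)                   # share of dry days correctly called dry
accuracy    <- (tp + tn) / (tp + fp + fn + tn)  # share of all days classified correctly
print(round(c(sens = sensitivity, spec = specificity, acc = accuracy), 3))
```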
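aus_final <- aus_wf %>%
add_model(logistic_reg() %>% set_engine("glm")) %>%
last_fit(aus_split)   ### Fit on the training set, evaluate on the test set
aus_final
## # Resampling results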
## # Monte Carlo cross-validation (0.75/0.25) with 1 resamples
## # A tibble: 1 x 6
## splits id .metrics .notes .predictions .workflow
## <list> <chr> <list> <list> <list> <list>
## 1 <split [84.7K~ train/test~ <tibble [2 ~ <tibble [0~ <tibble [28,230~ <workflo~
aus_final %>%
collect_metrics()
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.851
## 2 roc_auc binary 0.874
## Confusion matrix on the Test data
aus_final %>%
collect_predictions() %>%
conf_mat(RainTomorrow, .pred_class)
## Truth
## Prediction Yes No
## Yes 3221 1166
## No 3033 20810
aus_final$.workflow[[1]] %>%
tidy(exponentiate = TRUE) %>%
arrange(p.value) %>%
filter(term != "(Intercept)", p.value < 0.05) %>%
knitr::kable()
| term | estimate | std.error | statistic | p.value |
|---|---|---|---|---|
| WindGustSpeed | 0.9535625 | 0.0009876 | -48.148050 | 0.0000000 |
| Humidity3pm | 0.9265660 | 0.0012420 | -61.409036 | 0.0000000 |
| Pressure3pm | 1.2640594 | 0.0073163 | 32.028070 | 0.0000000 |
| Pressure9am | 0.8412601 | 0.0073046 | -23.663736 | 0.0000000 |
| LocationWollongong | 7.0066112 | 0.0963381 | 20.208562 | 0.0000000 |
| LocationNorahHead | 6.1855316 | 0.0961981 | 18.942301 | 0.0000000 |
| LocationTownsville | 7.0236202 | 0.1068185 | 18.248513 | 0.0000000 |
| LocationDarwin | 5.7341085 | 0.1033530 | 16.897745 | 0.0000000 |
| LocationNorfolkIsland | 4.7411831 | 0.0926892 | 16.790369 | 0.0000000 |
| LocationGoldCoast | 5.0003229 | 0.0981166 | 16.403977 | 0.0000000 |
| RainTodayYes | 0.6441022 | 0.0286172 | -15.371814 | 0.0000000 |
| LocationBallarat | 3.5735074 | 0.1010975 | 12.597219 | 0.0000000 |
| LocationCairns | 3.5362171 | 0.1012029 | 12.480443 | 0.0000000 |
| LocationSale | 3.1321139 | 0.0996143 | 11.461283 | 0.0000000 |
| LocationSydneyAirport | 2.8984182 | 0.0952568 | 11.171542 | 0.0000000 |
| MinTemp | 0.9485147 | 0.0048122 | -10.984231 | 0.0000000 |
| LocationCoffsHarbour | 2.8776826 | 0.0970052 | 10.896175 | 0.0000000 |
| LocationHobart | 2.7877858 | 0.0966999 | 10.602360 | 0.0000000 |
| MonthApr | 0.5749861 | 0.0558724 | -9.904883 | 0.0000000 |
| LocationMelbourneAirport | 2.6268596 | 0.0980006 | 9.854933 | 0.0000000 |
| LocationKatherine | 5.6573074 | 0.1828811 | 9.475817 | 0.0000000 |
| MonthMay | 0.5644623 | 0.0608760 | -9.394201 | 0.0000000 |
| MonthAug | 0.5344321 | 0.0669010 | -9.365346 | 0.0000000 |
| LocationWalpole | 2.3291937 | 0.0917895 | 9.211538 | 0.0000000 |
| LocationLaunceston | 2.9593517 | 0.1202747 | 9.020772 | 0.0000000 |
| LocationWoomera | 3.0095624 | 0.1314665 | 8.380800 | 0.0000000 |
| LocationMelbourne | 2.3239307 | 0.1013932 | 8.316732 | 0.0000000 |
| LocationNhil | 2.9143598 | 0.1299656 | 8.230258 | 0.0000000 |
| LocationWilliamtown | 2.2707106 | 0.1012217 | 8.101950 | 0.0000000 |
| LocationPortland | 2.0909112 | 0.0916571 | 8.047388 | 0.0000000 |
| MonthOct | 0.6440919 | 0.0576255 | -7.634011 | 0.0000000 |
| MonthJul | 0.5947566 | 0.0689688 | -7.533888 | 0.0000000 |
| LocationDartmoor | 2.0994865 | 0.0989720 | 7.493969 | 0.0000000 |
| Rainfall | 0.9898931 | 0.0013817 | -7.352037 | 0.0000000 |
| LocationSydney | 2.0583515 | 0.0989953 | 7.292324 | 0.0000000 |
| MonthJun | 0.6354007 | 0.0659861 | -6.872651 | 0.0000000 |
| MonthMar | 0.7060728 | 0.0509373 | -6.832653 | 0.0000000 |
| MonthSep | 0.6758546 | 0.0624927 | -6.269170 | 0.0000000 |
| LocationWatsonia | 1.7698373 | 0.0951843 | 5.997708 | 0.0000000 |
| LocationMountGambier | 1.7505971 | 0.0935886 | 5.983172 | 0.0000000 |
| MonthNov | 0.7315871 | 0.0531791 | -5.877098 | 0.0000000 |
| LocationMoree | 1.9346420 | 0.1131659 | 5.831461 | 0.0000000 |
| LocationCanberra | 1.8342662 | 0.1042948 | 5.816635 | 0.0000000 |
| LocationRichmond | 1.7923831 | 0.1050092 | 5.557094 | 0.0000000 |
| LocationPearceRAAF | 1.8054414 | 0.1065746 | 5.543581 | 0.0000000 |
| WindDir9amNNE | 0.6916581 | 0.0666559 | -5.530847 | 0.0000000 |
| LocationBadgerysCreek | 1.7651188 | 0.1030021 | 5.516569 | 0.0000000 |
| LocationNuriootpa | 1.7437910 | 0.1016948 | 5.467945 | 0.0000000 |
| WindDir3pmNNW | 0.6724014 | 0.0746874 | -5.314145 | 0.0000001 |
| LocationTuggeranong | 1.6928636 | 0.1063104 | 4.951740 | 0.0000007 |
| LocationBendigo | 1.6247072 | 0.1034191 | 4.692824 | 0.0000027 |
| LocationBrisbane | 1.5329184 | 0.0959964 | 4.449889 | 0.0000086 |
| LocationAliceSprings | 1.7339241 | 0.1272068 | 4.326710 | 0.0000151 |
| LocationWitchcliffe | 1.4819493 | 0.0967864 | 4.064189 | 0.0000482 |
| WindDir9amNE | 0.7620239 | 0.0680087 | -3.996215 | 0.0000644 |
| WindDir3pmNW | 0.7550289 | 0.0739640 | -3.799133 | 0.0001452 |
| WindDir9amENE | 0.7873933 | 0.0668736 | -3.574316 | 0.0003511 |
| LocationMildura | 1.4871594 | 0.1112364 | 3.567787 | 0.0003600 |
| WindGustDirNE | 1.2719242 | 0.0720104 | 3.340223 | 0.0008371 |
| WindDir3pmNE | 1.2603053 | 0.0695798 | 3.325016 | 0.0008841 |
| MonthDec | 0.8399735 | 0.0525949 | -3.315622 | 0.0009144 |
| WindDir9amWSW | 0.7965104 | 0.0690441 | -3.295215 | 0.0009835 |
| WindDir3pmSW | 1.2709539 | 0.0730009 | 3.284450 | 0.0010218 |
| WindGustDirESE | 0.8041092 | 0.0681775 | -3.197831 | 0.0013847 |
| MonthFeb | 0.8461557 | 0.0530487 | -3.149030 | 0.0016381 |
| LocationPerthAirport | 1.3549886 | 0.0967646 | 3.139507 | 0.0016923 |
| LocationUluru | 1.6490991 | 0.1593420 | 3.139343 | 0.0016933 |
| LocationWaggaWagga | 1.3741888 | 0.1028346 | 3.091018 | 0.0019947 |
| WindDir3pmN | 0.8045834 | 0.0718612 | -3.025703 | 0.0024806 |
| WindGustDirNNE | 1.2447492 | 0.0746476 | 2.932902 | 0.0033581 |
| Humidity9am | 0.9970742 | 0.0010173 | -2.880218 | 0.0039740 |
| WindDir3pmSSE | 1.2051445 | 0.0685051 | 2.723877 | 0.0064521 |
| LocationAlbury | 1.3181576 | 0.1050546 | 2.629443 | 0.0085525 |
| WindGustDirSE | 0.8391625 | 0.0670882 | -2.613737 | 0.0089558 |
| WindGustDirSSE | 0.8466773 | 0.0694731 | -2.395682 | 0.0165895 |
| LocationPerth | 1.2502756 | 0.0951999 | 2.346262 | 0.0189628 |
| WindDir9amW | 0.8537189 | 0.0680669 | -2.323496 | 0.0201525 |
| WindDir9amSSE | 1.1688668 | 0.0672034 | 2.321828 | 0.0202422 |
| WindDir9amS | 1.1610250 | 0.0675662 | 2.209732 | 0.0271238 |
| WindGustDirSW | 0.8604963 | 0.0712199 | -2.109607 | 0.0348922 |
| WindGustDirS | 0.8662776 | 0.0697163 | -2.059056 | 0.0394889 |
| LocationCobar | 1.2518127 | 0.1100051 | 2.041657 | 0.0411855 |
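Read these odds ratios with care. `glm()` with a binomial family treats the first factor level as "failure" and models the probability of the other level. Because RainTomorrow was releveled so that "Yes" comes first, the exponentiated coefficients above are odds ratios for no rain tomorrow. An estimate below 1 (for example Humidity3pm, 0.927) therefore means that higher values raise the odds of rain. The simulated check below uses made-up data, not weatherAUS.

```r
set.seed(1)
n <- 2000
# Simulated data: the chance of rain rises with humidity
humidity <- runif(n, 0, 100)
p_rain <- plogis(-4 + 0.08 * humidity)
rain <- factor(ifelse(runif(n) < p_rain, "Yes", "No"), levels = c("Yes", "No"))

# With levels c("Yes", "No"), glm() models P(rain == "No")
fit <- glm(rain ~ humidity, family = binomial)

or_no  <- unname(exp(coef(fit)["humidity"]))  # odds ratio for "No" per unit of humidity
or_yes <- 1 / or_no                            # equivalent odds ratio for "Yes"
print(round(c(or_no = or_no, or_yes = or_yes), 3))
```

Even though humidity increases rain in this simulation, or_no comes out below 1. The table above should be read the same way.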
mtry is the number of predictors sampled at random as split candidates at each node. The node is then split on whichever candidate gives the largest improvement in purity; Gini impurity is ranger's default split rule for classification. min_n is the minimum number of observations a node must contain to be split further. Larger values give shallower trees and help control overfitting.
The number of trees is also a hyperparameter, but here we fix it at 500. To classify an observation, the forest passes it down all 500 trees and combines their predictions, so the class predicted by the most trees wins. Averaging over many decorrelated trees is the core idea of a random forest.
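The voting step can be sketched in a few lines of base R. The per-tree predictions below are hypothetical. This is a simplified majority vote; the exact aggregation inside ranger depends on its settings, for example whether a probability forest is grown.

```r
# Hypothetical votes: rows are 5 trees, columns are 3 observations
votes <- matrix(c("Yes", "No",  "No",
                  "Yes", "No",  "Yes",
                  "No",  "No",  "Yes",
                  "Yes", "Yes", "Yes",
                  "No",  "No",  "No"),
                nrow = 5, byrow = TRUE)

# Majority vote per observation (column); an odd number of trees avoids ties here
majority <- apply(votes, 2, function(v) names(which.max(table(v))))

# Share of trees voting "Yes", a simple vote-based probability
prob_yes <- colMeans(votes == "Yes")

print(majority)
print(prob_yes)
```

The vote share gives a probability-like score. A curve such as ROC is built by thresholding scores like this at many levels rather than only at 0.5.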
A very deep tree, with many terminal nodes, has high variance: its predictions for new data can change a lot with small changes in the training data.
There is no fixed rule for choosing these hyperparameters, so they have to be tuned empirically on the data.
We choose mtry and min_n by cross-validation. For each candidate combination, a random forest is fitted on the analysis portion of each of the 10 folds and evaluated on that fold's held-out portion. Accuracy and ROC AUC are then averaged across the folds.
We pick the combination with the best cross-validated accuracy (or AUC) and refit a model with it on the whole training set. We then evaluate that model on the test set to check for overfitting.
tune_grid() from the tune package runs this search over the resamples.
## Create a workflow for tuning random forest
## Add the random forest model to the workflow
## Then start tuning ===>>>
tune_wf <- aus_wf %>%
add_model(rf_spec)
tune_wf
## == Workflow ===========================================
## Preprocessor: Formula
## Model: rand_forest()
##
## -- Preprocessor ---------------------------------------
## RainTomorrow ~ .
##
## -- Model ----------------------------------------------
## Random Forest Model Specification (classification)
##
## Main Arguments:
## mtry = tune()
## trees = 500
## min_n = tune()
##
## Computational engine: ranger
We evaluate 25 candidate combinations of mtry and min_n. With grid = 25, tune_grid() chooses the candidate values automatically using a space-filling design. The number of trees in each forest is fixed at 500 in rf_spec. Both numbers could be increased, but tuning a random forest is slow, so we keep them moderate.
We chose 25 candidates because it is slightly more than the number of columns in aus_df (16: 15 predictors plus the outcome). This does not guarantee that every combination is tried, though. mtry can range from 1 to 15 and min_n by default from 2 to 40, so the 25 candidates only sample that space.
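The base-R sketch below makes the size of the search space concrete. It counts all integer combinations in the ranges above and draws 25 candidates with a simple Latin-hypercube-style helper, `lhs_int()`, which is hypothetical. tune uses its own space-filling design, so the actual candidate values will differ. The point is only that 25 candidates cover a small fraction of the combinations.

```r
set.seed(123)
n_cand <- 25

# All integer combinations in the assumed ranges (mtry 1..15, min_n 2..40)
full_grid <- expand.grid(mtry = 1:15, min_n = 2:40)
print(nrow(full_grid))

# Hypothetical helper: one uniform draw per equal-width stratum of [lo, hi], shuffled
lhs_int <- function(lo, hi, n) {
  breaks <- seq(lo, hi + 1, length.out = n + 1)
  draws  <- runif(n, min = breaks[-(n + 1)], max = breaks[-1])
  sample(floor(draws))
}

candidates <- data.frame(mtry  = lhs_int(1, 15, n_cand),
                         min_n = lhs_int(2, 40, n_cand))
print(head(candidates))

# Share of all combinations that the 25 candidates actually visit
coverage <- nrow(unique(candidates)) / nrow(full_grid)
print(round(coverage, 3))
```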
### Takes 300 mins
rf_rs_tune <- tune_grid(
object = tune_wf,
resamples = aus_cv,
grid = 25,
control = control_grid(save_pred = TRUE) ### Keep held-out predictions for ROC curves
)
### Results of Tuning
rf_rs_tune
## # Tuning results
## # 10-fold cross-validation using stratification
## # A tibble: 10 x 5
## splits id .metrics .notes .predictions
## <list> <chr> <list> <list> <list>
## 1 <split [76.2K/8.5~ Fold01 <tibble [50 x ~ <tibble [0 x ~ <tibble [211,750 x ~
## 2 <split [76.2K/8.5~ Fold02 <tibble [50 x ~ <tibble [0 x ~ <tibble [211,750 x ~
## 3 <split [76.2K/8.5~ Fold03 <tibble [50 x ~ <tibble [0 x ~ <tibble [211,750 x ~
## 4 <split [76.2K/8.5~ Fold04 <tibble [50 x ~ <tibble [0 x ~ <tibble [211,750 x ~
## 5 <split [76.2K/8.5~ Fold05 <tibble [50 x ~ <tibble [0 x ~ <tibble [211,750 x ~
## 6 <split [76.2K/8.5~ Fold06 <tibble [50 x ~ <tibble [0 x ~ <tibble [211,725 x ~
## 7 <split [76.2K/8.5~ Fold07 <tibble [50 x ~ <tibble [0 x ~ <tibble [211,725 x ~
## 8 <split [76.2K/8.5~ Fold08 <tibble [50 x ~ <tibble [0 x ~ <tibble [211,725 x ~
## 9 <split [76.2K/8.5~ Fold09 <tibble [50 x ~ <tibble [0 x ~ <tibble [211,725 x ~
## 10 <split [76.2K/8.5~ Fold10 <tibble [50 x ~ <tibble [0 x ~ <tibble [211,725 x ~
## # A tibble: 50 x 8
## mtry min_n .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 2 16 accuracy binary 0.856 10 0.000951 Model01
## 2 2 16 roc_auc binary 0.888 10 0.00115 Model01
## 3 7 6 accuracy binary 0.858 10 0.000868 Model02
## 4 7 6 roc_auc binary 0.888 10 0.00111 Model02
## 5 10 34 accuracy binary 0.857 10 0.000744 Model03
## 6 10 34 roc_auc binary 0.885 10 0.00111 Model03
## 7 9 8 accuracy binary 0.858 10 0.000773 Model04
## 8 9 8 roc_auc binary 0.887 10 0.000996 Model04
## 9 6 40 accuracy binary 0.856 10 0.000848 Model05
## 10 6 40 roc_auc binary 0.887 10 0.00106 Model05
## # ... with 40 more rows
The autoplot shows accuracy and ROC AUC for each combination of mtry and min_n, which helps us choose the best hyperparameter values. We can select the combination with the best accuracy or the one with the best AUC. Which metric to prefer depends on the problem and on the relative cost of different errors.
autoplot(rf_rs_tune)
rf_rs_tune %>%
collect_predictions() %>%
filter(.config == select_best(rf_rs_tune, metric = "roc_auc")$.config) %>% ### keep only the best-AUC candidate
group_by(id) %>%
roc_curve(RainTomorrow, .pred_Yes) %>%
ggplot(aes(1 - specificity, sensitivity, color = id)) +
geom_abline(lty = 2, color = "gray80", linewidth = 1.5) +
geom_path(show.legend = TRUE, alpha = 0.6, linewidth = 1.2) +
coord_equal()
For the best-AUC candidate, the ROC curves of the 10 folds nearly overlap. The random forest's trade-off between true and false positives is therefore stable across resamples.
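ROC AUC, the second metric reported throughout, has a handy interpretation. It is the probability that a randomly chosen rainy day receives a higher predicted probability of rain than a randomly chosen dry day, with ties counting one half. The base-R sketch below checks this on a handful of made-up scores.

```r
# Hypothetical predicted probabilities of rain
p_yes_rainy <- c(0.9, 0.8, 0.4)        # days where RainTomorrow = "Yes"
p_yes_dry   <- c(0.7, 0.3, 0.2, 0.1)   # days where RainTomorrow = "No"

# Compare every rainy/dry pair: a win counts 1, a tie counts 0.5
pair_wins <- outer(p_yes_rainy, p_yes_dry, ">") + 0.5 * outer(p_yes_rainy, p_yes_dry, "==")
auc <- mean(pair_wins)
print(auc)
```

Only one of the 12 pairs is misordered (0.4 vs 0.7), so the AUC is 11/12.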
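### Check which model has the highest accuracy
best_acc <- select_best(rf_rs_tune, metric = "accuracy")
best_acc
## # A tibble: 1 x 3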
## mtry min_n .config
## <int> <int> <chr>
## 1 3 4 Model12
### Also check which model has the highest AUC
best_auc <- select_best(rf_rs_tune, metric = "roc_auc")
best_auc
## # A tibble: 1 x 3
## mtry min_n .config
## <int> <int> <chr>
## 1 3 4 Model12
We go with accuracy this time; here it picks the same candidate as AUC (mtry = 3, min_n = 4). Finalizing the model plugs these values into rf_spec.
rf_final <- finalize_model(rf_spec, select_best(rf_rs_tune, metric = "accuracy"))
rf_final
## Random Forest Model Specification (classification)
##
## Main Arguments:
## mtry = 3
## trees = 500
## min_n = 4
##
## Computational engine: ranger
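last_fit() fits the finalized workflow on the full training set and evaluates it once on the test set of the split.
final_res <- tune_wf %>%
finalize_workflow(select_best(rf_rs_tune, metric = "accuracy")) %>%
last_fit(aus_split)
final_res %>%
collect_metrics()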
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.859
## 2 roc_auc binary 0.893
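Test-set accuracy (0.859) and ROC AUC (0.893) are in line with the cross-validation results, so there is no sign that the model overfit the training set. Both are also slightly higher than the logistic regression's test-set values (accuracy 0.851, ROC AUC 0.874).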
## Truth
## Prediction Yes No
## Yes 3275 995
## No 2979 20981
final_res %>%
collect_predictions() %>%
conf_mat(RainTomorrow, .pred_class) %>%
autoplot(type = "heatmap")
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 sens binary 0.524
## # A tibble: 1 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 spec binary 0.955
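Comparing the two test-set confusion matrices side by side shows where the random forest's gain comes from. It catches slightly more rainy days (3,275 vs 3,221 out of 6,254) and raises fewer false alarms (995 vs 1,166). The sketch below recomputes the metrics from the counts printed in this document.

```r
# Test-set confusion matrices copied from above (rows = prediction, columns = truth)
cm <- list(
  logistic      = matrix(c(3221, 1166, 3033, 20810), nrow = 2, byrow = TRUE,
                         dimnames = list(Prediction = c("Yes", "No"), Truth = c("Yes", "No"))),
  random_forest = matrix(c(3275, 995, 2979, 20981), nrow = 2, byrow = TRUE,
                         dimnames = list(Prediction = c("Yes", "No"), Truth = c("Yes", "No")))
)

# Sensitivity, specificity and accuracy with "Yes" as the event
summarise_cm <- function(m) {
  c(sens = m["Yes", "Yes"] / sum(m[, "Yes"]),
    spec = m["No", "No"] / sum(m[, "No"]),
    acc  = sum(diag(m)) / sum(m))
}

test_metrics <- sapply(cm, summarise_cm)
print(round(test_metrics, 3))
```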
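Permutation importance measures how much each variable matters. The values of one variable are randomly shuffled and the resulting drop in prediction accuracy is recorded; ranger computes this on the out-of-bag samples, so the forest is not refitted. A large drop means the variable has a strong impact on the prediction.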
library(vip)
rf_final %>%
set_engine("ranger", importance = "permutation") %>%
fit(RainTomorrow ~ . ,
data = aus_df
) %>%
vip(geom = "col")
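The permutation idea can be sketched in a few lines of base R. This is a simplified stand-in, not what vip or ranger run. It uses simulated data, a logistic regression instead of a forest, and scores on the training rows rather than out-of-bag samples. Shuffling a predictor that drives the outcome (x1) costs accuracy, while shuffling pure noise (x2) costs almost nothing.

```r
set.seed(42)
n <- 3000
# Simulated data: x1 drives the outcome, x2 is pure noise
x1 <- rnorm(n)
x2 <- rnorm(n)
y <- factor(ifelse(runif(n) < plogis(2 * x1), "Yes", "No"), levels = c("No", "Yes"))
dat <- data.frame(y, x1, x2)

# A logistic regression stands in for the forest; with levels No/Yes, glm models P(Yes)
fit <- glm(y ~ x1 + x2, family = binomial, data = dat)

accuracy_on <- function(d) {
  p_yes <- predict(fit, newdata = d, type = "response")
  mean((p_yes > 0.5) == (d$y == "Yes"))
}

base_acc <- accuracy_on(dat)

# Shuffle one predictor at a time and record the drop in accuracy; the model is not refitted
perm_importance <- sapply(c("x1", "x2"), function(v) {
  shuffled <- dat
  shuffled[[v]] <- sample(shuffled[[v]])
  base_acc - accuracy_on(shuffled)
})

print(round(perm_importance, 3))
```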
## 5 Comment on the Model
### 5.1 Findings
### 5.2 Further Improvement