Data may encompass numerous variables. Including predictors that are highly correlated with one another, or that are irrelevant to the outcome variable, can result in overfitting and compromise the reliability of predictions. During model deployment, unnecessary variables may also increase the cost of data collection and processing.
The curse of dimensionality refers to various challenges and issues that arise when working with high-dimensional data. As the number of features or variables in a dataset increases, the volume of the data space grows exponentially. This can lead to several problems, including increased sparsity of data, computational complexity, and the risk of overfitting in predictive models.
In high-dimensional spaces, data points become more dispersed, making it harder to find meaningful patterns. The curse of dimensionality highlights the difficulties and limitations associated with analyzing and modeling data in spaces with a large number of dimensions. Techniques like dimensionality reduction are often employed to mitigate these challenges and extract relevant information from high-dimensional datasets.
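The growing sparsity can be illustrated with a small simulation. The following is a minimal sketch, not part of the original example: it assumes points drawn uniformly from the unit hypercube, and the helper `mean_nn_distance()` is a hypothetical name introduced here.

```r
# Illustrative simulation: with a fixed number of points, the average distance
# to the nearest neighbour grows as the number of dimensions increases.
set.seed(1)

mean_nn_distance <- function(n_points, n_dims) {
  # Draw n_points uniformly from the unit hypercube [0, 1]^n_dims
  pts <- matrix(runif(n_points * n_dims), nrow = n_points)
  d <- as.matrix(dist(pts))   # pairwise Euclidean distances
  diag(d) <- Inf              # ignore each point's distance to itself
  mean(apply(d, 1, min))      # average distance to the nearest neighbour
}

dims <- c(1, 2, 10, 50)
nn_dist <- sapply(dims, function(p) mean_nn_distance(n_points = 200, n_dims = p))
data.frame(dimensions = dims, mean_nearest_neighbour_distance = nn_dist)
```

With the same 200 points, neighbours that are close in one dimension end up far apart in fifty, which is why methods that depend on local neighbourhoods need far more data in high dimensions.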
We will introduce the following dimensionality reduction methods:
Combine categories: Involves merging or grouping similar categories within categorical variables. Reduces the number of unique values, simplifying the dataset.
Use data summaries (such as correlation): Involves analyzing data summaries like correlation coefficients to identify and retain only the most relevant variables. Variables with high correlations may be redundant, and one of them can be removed.
Convert categorical data to numerical: Transforms categorical variables into numerical representations. Enables the use of numerical techniques on categorical data, facilitating analysis.
Use principal component analysis (PCA): A dimension reduction technique that transforms the original features into a new set of uncorrelated variables (principal components). Retains most of the important information while reducing the number of dimensions.
Use supervised learning methods: Utilizes techniques like feature selection or feature extraction within a supervised learning framework. Focuses on retaining features that are most relevant for predicting the target variable.
We use the diamonds data from the ggplot2 package (loaded with the tidyverse) to demonstrate the aforementioned dimension reduction techniques. The first few observations of the data are:
## # A tibble: 6 × 10
## carat cut color clarity depth table price x y z
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
## 2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
## 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
## 4 0.29 Premium I VS2 62.4 58 334 4.2 4.23 2.63
## 5 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
## 6 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
## # A tibble: 6 × 11
## carat cut color clarity depth table price x y z CombinedCut
## <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <chr>
## 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43 HighQuality
## 2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31 HighQuality
## 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31 Other
## 4 0.29 Premium I VS2 62.4 58 334 4.2 4.23 2.63 HighQuality
## 5 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75 Other
## 6 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48 Other
## tibble [53,940 × 11] (S3: tbl_df/tbl/data.frame)
## $ carat : num [1:53940] 0.23 0.21 0.23 0.29 0.31 0.24 0.24 0.26 0.22 0.23 ...
## $ cut : Ord.factor w/ 5 levels "Fair"<"Good"<..: 5 4 2 4 2 3 3 3 1 3 ...
## $ color : Ord.factor w/ 7 levels "D"<"E"<"F"<"G"<..: 2 2 2 6 7 7 6 5 2 5 ...
## $ clarity : Ord.factor w/ 8 levels "I1"<"SI2"<"SI1"<..: 2 3 5 4 2 6 7 3 4 5 ...
## $ depth : num [1:53940] 61.5 59.8 56.9 62.4 63.3 62.8 62.3 61.9 65.1 59.4 ...
## $ table : num [1:53940] 55 61 65 58 58 57 57 55 61 61 ...
## $ price : int [1:53940] 326 326 327 334 335 336 336 337 337 338 ...
## $ x : num [1:53940] 3.95 3.89 4.05 4.2 4.34 3.94 3.95 4.07 3.87 4 ...
## $ y : num [1:53940] 3.98 3.84 4.07 4.23 4.35 3.96 3.98 4.11 3.78 4.05 ...
## $ z : num [1:53940] 2.43 2.31 2.31 2.63 2.75 2.48 2.47 2.53 2.49 2.39 ...
## $ CombinedCut: chr [1:53940] "HighQuality" "HighQuality" "Other" "HighQuality" ...
## carat cut color clarity depth
## Min. :0.2000 Fair : 1610 D: 6775 SI1 :13065 Min. :43.00
## 1st Qu.:0.4000 Good : 4906 E: 9797 VS2 :12258 1st Qu.:61.00
## Median :0.7000 Very Good:12082 F: 9542 SI2 : 9194 Median :61.80
## Mean :0.7979 Premium :13791 G:11292 VS1 : 8171 Mean :61.75
## 3rd Qu.:1.0400 Ideal :21551 H: 8304 VVS2 : 5066 3rd Qu.:62.50
## Max. :5.0100 I: 5422 VVS1 : 3655 Max. :79.00
## J: 2808 (Other): 2531
## table price x y
## Min. :43.00 Min. : 326 Min. : 0.000 Min. : 0.000
## 1st Qu.:56.00 1st Qu.: 950 1st Qu.: 4.710 1st Qu.: 4.720
## Median :57.00 Median : 2401 Median : 5.700 Median : 5.710
## Mean :57.46 Mean : 3933 Mean : 5.731 Mean : 5.735
## 3rd Qu.:59.00 3rd Qu.: 5324 3rd Qu.: 6.540 3rd Qu.: 6.540
## Max. :95.00 Max. :18823 Max. :10.740 Max. :58.900
##
## z CombinedCut
## Min. : 0.000 Length:53940
## 1st Qu.: 2.910 Class :character
## Median : 3.530 Mode :character
## Mean : 3.539
## 3rd Qu.: 4.040
## Max. :31.800
##
We created a new variable CombinedCut based on the cut variable. It combines the categories “Ideal” and “Premium” into “HighQuality” and labels the rest as “Other”.
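The code that produced CombinedCut is not shown, so the following is only one possible way to do it. It is a sketch using base R `ifelse()` on a stand-in data frame (the hypothetical `toy`) holding the cut values of the six rows printed above.

```r
# Stand-in data: the cut values of the six rows printed above
toy <- data.frame(cut = c("Ideal", "Premium", "Good", "Premium", "Good", "Very Good"))

# "Ideal" and "Premium" become "HighQuality"; everything else becomes "Other"
toy$CombinedCut <- ifelse(toy$cut %in% c("Ideal", "Premium"), "HighQuality", "Other")
toy

# Assumed equivalent call on the full data:
# diamonds$CombinedCut <- ifelse(diamonds$cut %in% c("Ideal", "Premium"),
#                                "HighQuality", "Other")
```

The result for these six rows matches the CombinedCut column in the printed output.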
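library(caret) # provides findCorrelation()
correlation_matrix <- cor(diamonds[, c("carat", "depth", "table", "price")])
# findCorrelation() returns positions within correlation_matrix, not within diamonds,
# so request variable names and drop the columns by name
high_correlation_vars <- findCorrelation(correlation_matrix, cutoff = 0.8, names = TRUE)
diamonds <- diamonds[, !(names(diamonds) %in% high_correlation_vars)]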
print("After Correlation-based Reduction:")
## [1] "After Correlation-based Reduction:"
print(head(diamonds))
## # A tibble: 6 × 10
## cut color clarity depth table price x y z CombinedCut
## <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <chr>
## 1 Ideal E SI2 61.5 55 326 3.95 3.98 2.43 HighQuality
## 2 Premium E SI1 59.8 61 326 3.89 3.84 2.31 HighQuality
## 3 Good E VS1 56.9 65 327 4.05 4.07 2.31 Other
## 4 Premium I VS2 62.4 58 334 4.2 4.23 2.63 HighQuality
## 5 Good J SI2 63.3 58 335 4.34 4.35 2.75 Other
## 6 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48 Other
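We calculated the correlation matrix for selected numeric variables (carat, depth, table, price). Variables involved in a pairwise correlation above the cutoff of 0.8 were identified and removed. In the output above, carat has been dropped because of its high correlation with price.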
diamonds$CutNumeric <- as.numeric(factor(diamonds$cut))
print(head(diamonds))
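We created a new variable CutNumeric by converting the categorical variable cut into a numeric representation using as.numeric(factor(…)). Because cut is an ordered factor, the codes 1 through 5 follow the level order from “Fair” to “Ideal”.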
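Integer codes impose an ordering and equal spacing between categories. The sketch below contrasts them with indicator (dummy) columns built with base R `model.matrix()`. The small factor `toy_cut` is a hypothetical stand-in that reuses the five level names of cut.

```r
# Toy ordered factor with the same five levels as cut (values are illustrative)
cut_levels <- c("Fair", "Good", "Very Good", "Premium", "Ideal")
toy_cut <- factor(c("Ideal", "Premium", "Good", "Fair", "Very Good", "Ideal"),
                  levels = cut_levels, ordered = TRUE)

# Integer codes follow the level order: Fair = 1, ..., Ideal = 5
cut_numeric <- as.numeric(toy_cut)
cut_numeric

# Alternative: one indicator (dummy) column per level, which imposes no ordering
toy_nominal <- factor(toy_cut, ordered = FALSE)
cut_dummies <- model.matrix(~ toy_nominal - 1)
colnames(cut_dummies) <- cut_levels
cut_dummies
```

Integer codes keep a single column and suit ordered categories. Indicator columns add dimensions but make no assumption about the distance between levels.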
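# Selecting numeric variables for PCA
# (taken from the original ggplot2 data, because carat was dropped from diamonds above)
indices = sample(nrow(ggplot2::diamonds), size = 1000)
diamonds_numeric <- ggplot2::diamonds[indices, c("carat", "depth", "table", "price")]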
# Scaling the numeric variables
diamonds_scaled <- scale(diamonds_numeric)
# Performing PCA
pca_result <- prcomp(diamonds_scaled)
# Create a biplot
biplot(pca_result, cex = 0.7, col = c("blue", "red"), main = "PCA Biplot: First Two Principal Components")
abline(h=0, v=0, lty = "dashed", col = "red")
summary(pca_result)
Explanation of results:
- Points in the biplot represent observations (data points).
- The arrows represent the loadings of the original variables on the principal components.
- The direction of an arrow shows how the corresponding variable relates to the two displayed components. For example, an arrow pointing to the right indicates that larger values of that variable go with larger scores on the first principal component (the horizontal axis).
- The length of an arrow represents the strength of the loading for that variable. Longer arrows indicate a stronger association between the variable and the displayed principal components.
- The angle between arrows provides approximate information about the correlation between the variables. Arrows pointing in nearly the same direction indicate positively correlated variables, arrows pointing in opposite directions indicate negatively correlated variables, and arrows that are close to orthogonal (at a right angle) indicate weakly correlated variables.
In summary, the biplot() function is used to visualize the relationships between observations and variables in the first two principal components obtained from PCA. Points represent the observations, arrows represent the variables, and the colors help distinguish between them. The biplot allows you to see how variables contribute to the principal components and how observations are positioned in the reduced-dimensional space.
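To decide how many components to keep, it helps to look at the share of variance each component explains. This is a minimal sketch on the built-in mtcars data rather than the diamonds sample. The choice of the four columns is an assumption made only for illustration.

```r
# Illustrative PCA on four numeric columns of the built-in mtcars data
pca_demo <- prcomp(mtcars[, c("mpg", "disp", "hp", "wt")], scale. = TRUE)

# Loadings: one column per principal component, one row per original variable
pca_demo$rotation

# Proportion of variance explained by each component, and the running total
var_explained <- pca_demo$sdev^2 / sum(pca_demo$sdev^2)
round(var_explained, 3)
cumsum(var_explained)

# Scores: coordinates of the observations on the first two components
head(pca_demo$x[, 1:2])
```

The same quantities appear in the "Proportion of Variance" and "Cumulative Proportion" rows printed by summary() for a prcomp object.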
library(caret) # This package provides many modeling tools
set.seed(100)
library(randomForest)
DF = data.frame(x1 = rnorm(100),
x2 = rnorm(100),
x3 = rnorm(100),
x4 = rnorm(100),
x5 = rnorm(100),
x6 = rnorm(100),
x7 = rnorm(100),
z = factor(sample(c("A", "B", "C"), size = 100, replace = TRUE)) # factor, so z is treated as categorical
)
DF$y = 0.5 + 1*DF$x1 + 3*DF$x3 + 4*DF$x4 + 6*DF$x6 + 7*DF$x7 + 8*(DF$z=="B") + 9*(DF$z=="C") + rnorm(100)
# Select features using the Recursive Feature Elimination (RFE, a form of backward selection) method
ctrl <- rfeControl(functions = rfFuncs,
method = "cv",
number = 10)
result <- rfe(x = DF[ , -9],
y = DF$y,
sizes = c(1:3),
rfeControl = ctrl)
# Display the results of feature selection
result
# Extract selected features in order of importance
selected_features <- result$optVariables
selected_features
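Explanation of code:
rfeControl() specifies random-forest-based selection functions (rfFuncs) evaluated with 10-fold cross-validation. rfe() then compares subsets of size 1 to 3 with the full set of predictors, and result$optVariables lists the variables in the best-performing subset.
The result should be consistent with the creation of y: x2 and x5 do not enter the formula for y, so they are expected to be the least important features.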
Evaluating predictive performance typically involves using metrics and visualizations to assess how well a model is performing on a given task. We will use packages caret, pROC, and gains for performance metrics and charts.
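The general steps to evaluate the predictive performance of a model are: split the data into a training set and a holdout (test) set, fit the model on the training set, predict the response for the holdout set, and compare the predictions with the observed values using performance metrics and plots.
Suppose you already have predicted values for a quantitative response from a predictive model (say, a regression model for simplicity) and the corresponding observed values from a test dataset, stored in the vectors predicted and observed. The following code evaluates the predictions: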
# Produce performance measure: RMSE
mse <- mean((observed - predicted)^2)
rmse <- sqrt(mse)
# Produce performance measure: residuals plot
residuals <- observed - predicted
plot(predicted, residuals, main = "Residual Plot", xlab = "Predicted", ylab = "Residuals", col = "green")
abline(h = 0, col = "red") # Reference line at 0
These metrics and visualizations provide insights into different aspects of predictive performance. Lower values of MSE, RMSE, and MAE are desirable, indicating better model accuracy. For predictive R², closer to 1 indicates better explanatory power. Visualizations help to identify patterns, trends, and potential issues in the model predictions. Choose the metrics and visualizations based on the specific requirements and characteristics of your regression model and dataset.
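The four metrics named above can be computed in a few lines of base R. The values of `observed` and `predicted` below are hypothetical and serve only as an illustration. The predictive R² is taken here as 1 - SSE/SST around the mean of the observed test values, which is one common convention and an assumption of this sketch.

```r
# Hypothetical observed and predicted values, for illustration only
observed  <- c(21.0, 22.8, 18.7, 14.3, 24.4, 30.4)
predicted <- c(22.1, 24.0, 17.5, 15.0, 23.0, 27.9)

errors <- observed - predicted
mse  <- mean(errors^2)        # mean squared error
rmse <- sqrt(mse)             # root mean squared error, in the units of the response
mae  <- mean(abs(errors))     # mean absolute error, less sensitive to large errors

# Predictive R-squared: 1 - SSE / SST
r2_pred <- 1 - sum(errors^2) / sum((observed - mean(observed))^2)

c(MSE = mse, RMSE = rmse, MAE = mae, R2 = r2_pred)
```

RMSE is never smaller than MAE. A large gap between the two suggests that a few large errors dominate.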
Let’s look at an example with the mtcars data from base R. We first create a data frame using only 3 columns.
# For simplicity, select two predictors and a response variable
mtcars_subset <- mtcars[ , c("mpg", "wt", "hp")]
The above code selects specific columns (“mpg”, “wt”, and “hp”) from the mtcars dataset to create a subset named mtcars_subset.
# Set seed for reproducibility. The seed can be any integer.
set.seed(123)
# Determine the number of rows in the dataset
n <- nrow(mtcars_subset)
# Generate random indices for the training set (80%, rounded down to a whole number of rows)
train_indices <- sample(1:n, size = floor(0.8 * n))
# Create the training set using the sampled indices
train_data <- mtcars_subset[train_indices, ]
# Create the holdout set by excluding the training indices
holdout_data <- mtcars_subset[-train_indices, ]
Train the model:
# Fit a linear regression model on the training data, with mpg the response
lm_model <- lm(formula = mpg ~ wt + hp, data = train_data)
# Predict on both training and holdout data
train_predicted <- predict(lm_model, newdata = train_data)
holdout_predicted <- predict(lm_model, newdata = holdout_data)
The above code predicts the response variable for both the training and holdout datasets using the linear regression model (lm_model) that was previously fitted to the training data. The predicted values are stored in train_predicted and holdout_predicted, respectively.
# Assess predictive performance
rmse_train <- sqrt(mean((train_data$mpg - train_predicted)^2))
rmse_holdout <- sqrt(mean((holdout_data$mpg - holdout_predicted)^2))
# Print RMSE for training and holdout data
cat("RMSE on Training Data:", rmse_train, "\n")
cat("RMSE on Holdout Data:", rmse_holdout, "\n")
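# Visualize observed vs. predicted on both training and holdout data
# Set the axis limits from all observations so that no training points are cut off
plot(holdout_data$mpg, holdout_predicted, main = "Observed vs. Predicted",
xlab = "Observed", ylab = "Predicted", col = "blue", pch = 16,
xlim = range(mtcars_subset$mpg),
ylim = range(c(mtcars_subset$mpg, train_predicted, holdout_predicted)))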
# Add points for observed vs. predicted on training data
points(train_data$mpg, train_predicted, col = "green", pch = 16)
# Add reference line for perfect predictions
abline(a = 0, b = 1, col = "red")
# Create a legend
legend("bottomright",
legend = c("Holdout Data", "Training Data"),
col = c("blue", "green"),
pch = 16, cex = 1, bty = "n")
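Explanation of the code:
The RMSE is computed separately for the training and holdout data, and the scatter plot shows observed against predicted mpg for both sets. Points close to the red 45-degree line are predicted accurately. This gives a clear illustration of how the model performs on both the training and holdout datasets.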
When evaluating the predictive performance of a model for classifying a categorical response, there are several metrics and techniques you can use. We will focus on the binary classification scenario.
Sensitivity, specificity, and F1 score are common metrics used in binary classification to evaluate the performance of a model. These metrics are derived from the confusion matrix, which summarizes the classification results.
A confusion matrix may look like:
predicted_classes = c(0,0,1,0,0,0,1,1,0,1,0,0,1,1)
predicted_classes = factor(predicted_classes)
actual_classes = c(0,1,0,0,0,0,0,1,1,1,0,0,1,0)
actual_classes = factor(actual_classes)
# Create a confusion matrix using the confusionMatrix() from caret
conf_matrix <- confusionMatrix(predicted_classes, actual_classes, positive = "1")
conf_matrix$table
In the above code, the first argument of confusionMatrix() holds the predicted classes and the second holds the actual classes. The argument positive = "1" instructs the function to treat the class labeled “1” as the positive class. This means that in the confusion matrix, the “1” class will be considered the positive outcome, and performance metrics will be calculated based on the classification of instances as “1” (the positive class) or not “1” (the negative class).
Consequently, metrics such as sensitivity (true positive rate), specificity (true negative rate), precision, recall, and F1-score will be computed with respect to the “1” class.
Predicted / Actual | 0 | 1 |
---|---|---|
0 | True Negative | False Negative |
1 | False Positive | True Positive |
Let’s break down each metric:
Sensitivity (True Positive Rate or Recall) measures the proportion of actual positives that are correctly identified by the model. It is calculated as the ratio of true positives to the sum of true positives and false negatives. Formula: Sensitivity = True Positives / (True Positives + False Negatives). A high sensitivity indicates that the model is good at correctly identifying the positive class.
Specificity (True Negative Rate) measures the proportion of actual negatives that are correctly identified by the model. It is calculated as the ratio of true negatives to the sum of true negatives and false positives. Formula: Specificity = True Negatives / (True Negatives + False Positives). A high specificity indicates that the model is good at correctly identifying the negative class.
The F1 score is the harmonic mean of precision and recall (sensitivity). The harmonic mean of a set of numbers is defined as the reciprocal of the mean of the reciprocals of those numbers. It provides a balance between precision and recall and is especially useful when there is an imbalance between the classes.
Formula: \[F_1 = \frac{2}{\frac{1}{\text{Precision}} + \frac{1}{\text{Sensitivity}}} \] or, equivalently,
\[F_1 = \frac{2 \cdot \text{Precision} \cdot \text{Sensitivity}}{\text{Precision} + \text{Sensitivity}}\]
Precision (Positive Predictive Value) is the ratio of true positives to the sum of true positives and false positives.
A high F1 score indicates a good balance between precision and recall.
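In summary: sensitivity measures how well the positive class is detected, specificity measures how well the negative class is detected, and the F1 score balances precision and recall.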
These metrics are important for understanding how well a binary classification model performs, especially in situations where the class distribution is imbalanced. Depending on the specific goals and requirements of your application, you might prioritize specificity, sensitivity, or a balance between precision and recall (F1 score).
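The formulas can be checked by hand on the fourteen predicted and actual classes used for the confusion matrix example. This is a base R sketch. The variable names `predicted` and `actual` are introduced here to avoid overwriting the factors created earlier.

```r
# Same values as the confusion matrix example, kept as plain numeric vectors
predicted <- c(0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1)
actual    <- c(0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0)

# The four cells of the confusion matrix, with "1" as the positive class
TP <- sum(predicted == 1 & actual == 1)
TN <- sum(predicted == 0 & actual == 0)
FP <- sum(predicted == 1 & actual == 0)
FN <- sum(predicted == 0 & actual == 1)

sensitivity <- TP / (TP + FN)
specificity <- TN / (TN + FP)
precision   <- TP / (TP + FP)
f1 <- 2 * precision * sensitivity / (precision + sensitivity)

c(TP = TP, TN = TN, FP = FP, FN = FN)
c(Sensitivity = sensitivity, Specificity = specificity, Precision = precision, F1 = f1)
```

For these data the counts are TP = 3, TN = 6, FP = 3, and FN = 2, giving a sensitivity of 0.6, a specificity of about 0.667, a precision of 0.5, and an F1 score of about 0.545.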
Consider the famous “iris” dataset for a simple binary classification example. In this case, we’ll consider the task of classifying whether an iris flower is of the species “setosa” or not.
# Create a binary response variable: 1 if species is "setosa", 0 otherwise
iris$binary_response <- ifelse(iris$Species == "setosa", 1, 0)
The above code snippet creates a new binary response variable named “binary_response” in the iris dataset. The variable takes the value 1 if the species is “setosa” and 0 otherwise.
# Set seed for reproducibility
set.seed(123)
# Determine the number of rows in the iris dataset
n <- nrow(iris)
# Generate random indices for the training set (80%)
train_indices <- sample(1:n, 0.8 * n)
# Create the training set using the sampled indices
train_data <- iris[train_indices, ]
# Create the holdout set by excluding the training indices
holdout_data <- iris[-train_indices, ]
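The above code sets a seed for reproducibility, randomly samples 80% of the row indices (120 of the 150 flowers) for the training set, and assigns the remaining 30 rows to the holdout set.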
# Fit a logistic regression model on the training data
logistic_model <- glm(binary_response ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
data = train_data,
family = "binomial")
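The above code fits a logistic regression model to predict the binary response variable (binary_response) using the predictor variables Sepal.Length, Sepal.Width, Petal.Length, and Petal.Width. Because setosa is perfectly separated from the other two species by the petal measurements, glm() is likely to warn that fitted probabilities numerically 0 or 1 occurred. The fitted model can still be used for this illustration.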
# Make predictions on the holdout data
predicted_probs <- predict(logistic_model,
newdata = holdout_data,
type = "response")
# Convert predicted probabilities to class predictions (binary: 1 for predicted probability > 0.5)
predicted_classes <- ifelse(predicted_probs > 0.5, 1, 0)
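Explanation: predict() with type = “response” returns the predicted probability that binary_response equals 1 for each holdout observation. ifelse() then assigns class 1 when that probability exceeds 0.5 and class 0 otherwise.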
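# Convert predicted classes to factor, fixing the levels so that both classes
# are present even if the model predicts only one of them
predicted_classes <- factor(predicted_classes, levels = c(0, 1))
# Convert actual classes to factor with the same levels
actual_classes <- factor(holdout_data$binary_response, levels = c(0, 1))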
# Create a confusion matrix
conf_matrix <- confusionMatrix(predicted_classes, actual_classes, positive = "1")
# Print the confusion matrix
conf_matrix
# Extract and print performance metrics
conf_matrix$byClass
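Explanation of code: printing conf_matrix shows the confusion table together with overall statistics such as accuracy. conf_matrix$byClass is a named vector of class-specific metrics, including Sensitivity, Specificity, Precision, Recall, and F1, all computed with “1” as the positive class.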
The Receiver Operating Characteristic (ROC) curve is a graphical representation of the performance of a binary classification model across different discrimination thresholds. It is widely used in machine learning and statistics to visualize the trade-off between true positive rate (sensitivity) and false positive rate (1 - specificity) as the discrimination threshold varies.
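Here are the key components of an ROC curve: the true positive rate (sensitivity), the false positive rate (1 - specificity), the classification threshold that is varied to trace out the curve, and the area under the curve (AUC), which summarizes the curve in a single number.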
A typical ROC curve is plotted with sensitivity on the y-axis and 1 - specificity on the x-axis. The curve illustrates the trade-off between correctly identifying positive instances and incorrectly classifying negative instances as positive. The diagonal line (a straight line from (0,0) to (1,1)) represents the performance of a random classifier.
In summary, the ROC curve provides a visual representation of the model’s ability to discriminate between positive and negative instances across different threshold values. It is a valuable tool for understanding the performance of binary classification models and comparing different models.
library(ggplot2)
# The actual responses
y_actual = c(0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0)
# The predicted probabilities of the positive class
predicted_probs = c(0.43, 0.23, 0.18, 1, 0.41, 0.49, 0.43, 0.31, 0.6, 0.39, 0.63, 0.48, 0.23, 0.28, 0.45, 0.01, 0.82, 0.42, 0.29, 0.77, 0.99, 0.32, 0.31, 0.11, 0.49, 0.3, 0.33, 0.13, 0.5, 0.39, 0.37, 0.05, 0.02, 0.42, 0.95, 0.39, 0.71, 0.34, 0.99, 0.33, 0.79, 0.39, 0.51, 0.17, 0.47, 0.64, 0.11, 0.49, 0.42, 0.56, 0.4, 0.71, 0.21, 0.29, 0.33, 0.12, 0.34, 0.05, 0.18, 0.14, 0.54, 0.58, 0.29, 0.51, 0.57, 0.27, 0.04, 0.57, 0.47, 0.57, 0.41, 0.9, 0.56, 0.25, 0.25, 0.93, 0.77, 0.48, 0.74, 0.94, 0.19, 0.47, 0.58, 0.33, 0.92, 0.12, 0.84, 0.47, 0.61, 0.36, 0.46, 0.25, 0.07, 0.95, 0.53, 0.73, 0.07, 0.55, 0.56, 0.51)
# Load necessary library
library(pROC)
# Compute ROC curve
roc_curve <- roc(y_actual, predicted_probs)
# Plot ROC curve
plot(roc_curve, main = "Receiver Operating Characteristic (ROC) Curve",
col = "blue", lwd = 2, print.auc = TRUE, print.auc.x = 0.5, print.auc.y = 0.2)
The above code does the following:
y_actual = c(0, 0, 0, 1, 0, …): This line defines a vector y_actual containing the true labels (0 or 1) corresponding to the predicted probabilities.
predicted_probs = c(0.43, 0.23, 0.18, 1, 0.41, …): This line defines a vector predicted_probs containing the predicted probabilities for the positive class from the binary classification model.
library(pROC): This line loads the pROC library, which provides functions for computing and plotting Receiver Operating Characteristic (ROC) curves.
roc_curve <- roc(y_actual, predicted_probs): This line calculates the ROC curve based on the true labels (y_actual) and predicted probabilities (predicted_probs) using the roc() function from the pROC library.
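plot(roc_curve, …): This line plots the ROC curve using the plot() function. The parameters specify the main title (main), line color (col), line width (lwd), and whether to print the AUC (print.auc) along with its position (print.auc.x and print.auc.y). By default, pROC labels the x-axis as specificity running from 1 to 0, which traces the same curve as 1 - specificity running from 0 to 1. Adding legacy.axes = TRUE labels the axis as 1 - specificity instead.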
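To see what roc() does conceptually, the curve can be traced by hand: sweep the threshold over the predicted scores and record the true and false positive rates at each step. The sketch below uses base R and a small hypothetical dataset (`y`, `score`). It is meant as an illustration of the idea, not as a description of how pROC is implemented.

```r
# Hypothetical labels and scores, for illustration only
y     <- c(0, 0, 1, 0, 1, 1, 0, 1)
score <- c(0.10, 0.35, 0.40, 0.45, 0.60, 0.70, 0.80, 0.90)

# Classify as positive when score >= threshold, for each candidate threshold
thresholds <- c(Inf, sort(unique(score), decreasing = TRUE))
tpr <- sapply(thresholds, function(t) sum(score >= t & y == 1) / sum(y == 1))
fpr <- sapply(thresholds, function(t) sum(score >= t & y == 0) / sum(y == 0))

# Area under the curve by the trapezoidal rule
auc <- sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)
auc

plot(fpr, tpr, type = "l", xlab = "1 - Specificity", ylab = "Sensitivity")
abline(a = 0, b = 1, lty = "dashed")   # random classifier
```

For these eight observations the AUC is 0.75, which equals the proportion of (positive, negative) pairs in which the positive observation receives the higher score (12 of 16).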
Cumulative Gains Charts, also known as Lift Charts or Gain Charts, are graphical representations used to evaluate the performance of predictive models, especially binary classification models. They help in understanding how much better a model performs compared to a random classifier or baseline.
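library(gains)
gain = gains(y_actual, predicted_probs)
# Prepend the origin so that the curve starts at (0, 0)
plot(c(0, gain$cume.obs), c(0, gain$cume.pct.of.total*sum(y_actual)),
type = "l",
xlab = "# Observations",
ylab = "Cumulative Gains")
# Baseline: cumulative positives expected if observations were selected at random
lines(c(0, length(y_actual)), c(0, sum(y_actual)), lty = "dashed")
In the provided code, gains() sorts the observations by predicted probability, splits them into groups (ten by default), and reports cumulative statistics for each group. The solid line plots the cumulative number of positives captured (gain$cume.pct.of.total times the total number of positives) against the cumulative number of observations (gain$cume.obs). The dashed line is the baseline expected if observations were selected at random.
Interpreting a Cumulative Gains Chart: The further the solid curve rises above the dashed baseline, the better the model ranks positive cases ahead of negative ones. A curve that follows the baseline indicates a model that does no better than random selection.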
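The quantities behind a cumulative gains chart can also be computed directly, without grouping. The following base R sketch uses ten hypothetical observations (`y`, `prob`) and is only meant to make the construction concrete.

```r
# Hypothetical labels and predicted probabilities, for illustration only
y    <- c(1, 0, 1, 1, 0, 0, 1, 0, 0, 0)
prob <- c(0.95, 0.40, 0.85, 0.70, 0.65, 0.30, 0.55, 0.20, 0.15, 0.05)

ord <- order(prob, decreasing = TRUE)   # rank from most to least likely positive
cum_gains <- cumsum(y[ord])             # positives captured among the top k observations
baseline  <- seq_along(y) * mean(y)     # positives expected under random ordering

data.frame(k = seq_along(y), cum_gains, baseline)
```

Here the top five predictions already capture all four positives, whereas random selection would be expected to capture only two.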
A decile-wise lift chart is a graphical representation used to evaluate the performance of a predictive model, typically in the context of binary classification problems. The lift chart is divided into ten segments, or deciles, representing ten equal-sized groups of observations. Each decile corresponds to a range of predicted probabilities, and the chart helps assess how well the model distinguishes between positive and negative outcomes across these deciles.
Here’s a step-by-step explanation of how to create a decile-wise lift chart:
Model Training: Train your predictive model suitable for binary classification.
Prediction: Use the trained model to make predictions on a new dataset. The output of the model is typically one probability score per observation.
Decile Assignment: Rank the predicted probabilities in descending order and divide the dataset into ten equal-sized groups (deciles). The first decile contains the observations with the highest predicted probabilities, the second decile contains the next set of observations, and so on.
Calculate Cumulative Response Rates: Calculate the cumulative response rate for each decile. The cumulative response rate is the cumulative number of positive outcomes (e.g., events) divided by the total number of positive outcomes in the dataset.
Calculate Expected Cumulative Response Rates: Assuming observations are ordered at random, calculate the expected cumulative response rates for each decile. Each decile is then expected to contain 10% of the positive outcomes, so the expected cumulative response rate is a straight line rising from 10% at the first decile to 100% at the last.
Plot the Lift Chart: Create a line chart where the x-axis represents the deciles, and the y-axis represents the cumulative response rates. Plot both the actual cumulative response rates and the expected cumulative response rates.
Interpretation of the lift chart:
The lift chart visually shows how well the model performs across different deciles compared to a random model. A higher lift indicates that the model is better at identifying positive outcomes, while a lower lift suggests poorer performance.
In summary, a decile-wise lift chart helps you understand how well your predictive model is at differentiating between positive and negative outcomes across different segments of the dataset, providing insights into its discriminatory power.
Example:
Assume we have the following predicted probabilities based on a binary classification model such as logistic regression. The actual binary response values are observed.
actual = c(0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0)
predicted_probability = c(0.43, 0.23, 0.18, 1, 0.41, 0.49, 0.43, 0.31, 0.6, 0.39, 0.63, 0.48, 0.23, 0.28, 0.45, 0.01, 0.82, 0.42, 0.29, 0.77, 0.99, 0.32, 0.31, 0.11, 0.49, 0.3, 0.33, 0.13, 0.5, 0.39, 0.37, 0.05, 0.02, 0.42, 0.95, 0.39, 0.71, 0.34, 0.99, 0.33, 0.79, 0.39, 0.51, 0.17, 0.47, 0.64, 0.11, 0.49, 0.42, 0.56, 0.4, 0.71, 0.21, 0.29, 0.33, 0.12, 0.34, 0.05, 0.18, 0.14, 0.54, 0.58, 0.29, 0.51, 0.57, 0.27, 0.04, 0.57, 0.47, 0.57, 0.41, 0.9, 0.56, 0.25, 0.25, 0.93, 0.77, 0.48, 0.74, 0.94, 0.19, 0.47, 0.58, 0.33, 0.92, 0.12, 0.84, 0.47, 0.61, 0.36, 0.46, 0.25, 0.07, 0.95, 0.53, 0.73, 0.07, 0.55, 0.56, 0.51)
In the actual values, 32% are ones.
We use the gains() function from the gains package to create a table that includes lift values, defined as the mean response in each decile divided by the overall mean response, multiplied by 100. The object returned by gains() has an element called “lift”. Dividing these lifts by 100 and drawing a bar chart with the resulting values as heights, labeled 10, 20, …, 100, gives a decile-wise lift chart. A lift greater than 1 indicates that the model is outperforming random chance, while a lift less than 1 suggests that the model is performing worse than random chance. Higher lift values in the early deciles indicate that the model is particularly effective at identifying positive outcomes among the top predictions. A lift chart helps identify where the model provides the most benefit compared to random chance.
Let’s create a decile-wise lift chart using the above data:
library(gains)
# Use the vectors defined in this example
gain = gains(actual, predicted_probability)
barplot(gain$lift/100, names = gain$depth, xlab = "Percentile", ylab = "Lift")
Interpretation:
The lift is about 3.12 in the first decile. This means that the model captures more than 3 times as many positive outcomes in that decile as would be expected from random selection (imagine flipping a loaded coin that lands heads 32% of the time).
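The lift values can be reproduced without the gains package, which makes the definition explicit. The sketch below uses twenty hypothetical observations (`y`, `prob`), two per decile. It is an illustration of the calculation and not a description of what gains() does internally.

```r
# Hypothetical data (20 observations), for illustration only
y    <- c(1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0)
prob <- (19:0) / 20                     # predicted probabilities 0.95, 0.90, ..., 0.00

ord    <- order(prob, decreasing = TRUE)              # rank by predicted probability
decile <- ceiling(seq_along(y) / (length(y) / 10))    # 1 = top 10% of predictions

decile_mean <- tapply(y[ord], decile, mean)   # positive rate within each decile
lift <- decile_mean / mean(y)                 # ratio to the overall positive rate
round(lift, 2)

barplot(lift, names.arg = seq(10, 100, by = 10), xlab = "Percentile", ylab = "Lift")
abline(h = 1, lty = "dashed")                 # lift of a random ordering
```

With an overall positive rate of 0.3 and both observations in the top decile positive, the first bar has height 1 / 0.3, or about 3.33. Because the deciles are equal in size, the lifts average to 1.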
Multiple Linear Regression is a statistical method used to model the relationship between a dependent variable and two or more independent variables. The equation for a multiple linear regression model is:
\[y = \beta_0+\beta_1\cdot x_1+\beta_2\cdot x_2+\cdots +\beta_p\cdot x_p+\epsilon\] where \(\beta\)’s are the coefficients representing the impact of each independent variable, \(x\)’s are the independent variables, and \(\epsilon\) is the error term.
The coefficients are estimated using the method of least squares to minimize the sum of squared differences between the observed and predicted values.
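The least squares estimate has the closed form \(\hat\beta = (X^\top X)^{-1} X^\top y\), where \(X\) is the design matrix with a leading column of ones. The sketch below checks this against lm() on simulated data. The sample size, the coefficients (1, 2, -3), and the variable names are assumptions made for illustration.

```r
set.seed(42)
n  <- 200
x1 <- rnorm(n)
x2 <- rnorm(n)
y  <- 1 + 2 * x1 - 3 * x2 + rnorm(n)   # assumed true coefficients: 1, 2, -3

# Design matrix with a column of ones for the intercept
X <- cbind(Intercept = 1, x1 = x1, x2 = x2)

# Solve the normal equations (X'X) beta = X'y
beta_hat <- solve(t(X) %*% X, t(X) %*% y)

# Compare with the coefficients estimated by lm()
fit <- lm(y ~ x1 + x2)
cbind(normal_equations = drop(beta_hat), lm = coef(fit))
```

The two columns agree up to numerical precision. lm() uses a QR decomposition rather than the normal equations, which is numerically more stable when predictors are highly correlated.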
Multiple linear regression can be used for both explanatory analysis in traditional statistical applications and predictive modeling in machine learning, but the emphasis and goals can differ between the two approaches. Let’s discuss the main differences:
In traditional statistical applications, the primary goal of multiple linear regression is often to understand the relationships between variables and explain the variability in the response variable. The emphasis is on interpreting the coefficients of the independent variables (also called explanatory variables). Each coefficient represents the change in the mean of the response variable for a one-unit change in the corresponding independent variable, holding other variables constant. Traditional applications often involve rigorous checks for assumptions such as linearity, independence, homoscedasticity (i.e., equal variance), and normality of residuals. Violations of these assumptions may impact the validity of the interpretation. Variables are often selected based on statistical significance and the strength of their relationship with the response variable. Models are often kept relatively simple, focusing on the most relevant variables and avoiding overfitting. Hypothesis testing is commonly used to determine the significance of individual coefficients and the overall model fit.
In machine learning, the primary goal of multiple linear regression is often to accurately predict the response variable for new, unseen data. The emphasis is on building a model that accurately predicts the outcome for new observations. Interpretability of coefficients may be less important than predictive performance. While assumptions are still important, machine learning applications may be more forgiving of certain violations, and there is often a greater focus on predictive performance. Variables may be selected based on predictive performance metrics, such as mean squared error or R-squared, rather than strict statistical significance. Machine learning models can be more complex, allowing for interactions and non-linear relationships. Techniques such as regularization are often used to prevent overfitting. Evaluation is often based on accuracy metrics (e.g., root mean squared error) rather than traditional statistical measures. Cross-validation is commonly used to assess how well the model generalizes to new data.
In summary, while the fundamental principles of multiple linear regression remain the same, the emphasis and goals differ between explanatory analysis in traditional statistical applications and prediction in machine learning. Traditional applications focus on explaining relationships, interpreting coefficients, and adhering to strict assumptions, while machine learning applications prioritize predictive accuracy, flexibility, and generalization to new data.
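Cross-validation, mentioned above as the usual way to assess generalization, can be written in a few lines of base R. This is a minimal sketch of 5-fold cross-validation for a linear model. The simulated data frame `sim` and its coefficients are assumptions for illustration.

```r
set.seed(2024)
n <- 200
sim <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
sim$y <- 1 + 2 * sim$x1 - sim$x2 + rnorm(n)   # assumed true model

k <- 5
fold <- sample(rep(1:k, length.out = n))      # random fold label for each row

cv_rmse <- sapply(1:k, function(i) {
  fit  <- lm(y ~ x1 + x2, data = sim[fold != i, ])    # train on the other k - 1 folds
  pred <- predict(fit, newdata = sim[fold == i, ])    # predict the held-out fold
  sqrt(mean((sim$y[fold == i] - pred)^2))             # RMSE on the held-out fold
})

cv_rmse         # one RMSE per fold
mean(cv_rmse)   # cross-validated estimate of prediction error
```

Every observation is used for validation exactly once, so the averaged RMSE depends less on a single lucky or unlucky split than a single holdout estimate does.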
We use the following example to show how to train a multiple linear regression model for predictive purposes. We use simulated data so that you know the true model and how the model training process works.
We simulate data in which the response follows a multiple linear regression model with 3 independent variables (called predictors):
# Set seed for reproducibility
set.seed(123)
# Create simulated data
simulated_data <- data.frame(
Predictor1 = rnorm(2000),
Predictor2 = rnorm(2000),
Predictor3 = rnorm(2000)
)
# Add an extra column "Target" to the data frame
simulated_data$Target = 2 * simulated_data$Predictor1 - 3 * simulated_data$Predictor2 +
0.5 * simulated_data$Predictor3 + rnorm(2000)
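Breakdown of the above code: set.seed(123) makes the simulation reproducible. The data frame holds three independent standard normal predictors with 2000 observations each. Target is generated from the true model Target = 2 × Predictor1 - 3 × Predictor2 + 0.5 × Predictor3 + error, where the intercept is 0 and the error is standard normal noise.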
Next, we split the data into training and holdout sets:
# Split the data into training and holdout sets
set.seed(456) # Reset seed for reproducibility
n = nrow(simulated_data)
train_indices <- sample(1:n, size = round(0.8 * n))
train_data <- simulated_data[train_indices, ]
holdout_data <- simulated_data[-train_indices, ]
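Explanation of code: sample() draws 80% of the row indices (1600 of 2000) at random for the training set, and the remaining 400 rows form the holdout set.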
We next fit a regression model on the training data using the lm() function:
# Fit multiple regression model on training data
model <- lm(Target ~ Predictor1 + Predictor2 + Predictor3, data = train_data)
To make predictions on holdout_data, we use the predict() function:
# Make predictions on holdout data
predictions <- predict(model, newdata = holdout_data)
Finally, we calculate the performance metric RMSE:
# Calculate RMSE
rmse <- sqrt(mean((predictions - holdout_data$Target)^2))
rmse
Explanation of code:
The RMSE (Root Mean Squared Error) is calculated by taking the square root of the mean of the squared differences between the predicted values (predictions) and the actual target values (holdout_data$Target).
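The calculation can be wrapped in a small helper and checked by hand on a toy example. This is only a sketch; the function name rmse_fn and the toy numbers are illustrative.

```r
# RMSE as a reusable helper: square root of the mean squared difference
rmse_fn <- function(predicted, actual) {
  sqrt(mean((predicted - actual)^2))
}

# Toy check: the errors are 1, 0, -2, so the squared errors are 1, 0, 4,
# their mean is 5/3, and the RMSE is sqrt(5/3)
toy_rmse <- rmse_fn(c(2, 4, 6), c(1, 4, 8))
toy_rmse
```

For the simulated data, the error term has standard deviation 1, so a well-fitted model should give a holdout RMSE close to 1.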
Let’s calculate the RMSE based on the training data and compare it to the RMSE calculated on the holdout data.
# Make predictions on the training data
train_predictions <- predict(model, newdata = train_data)
# Calculate RMSE on the training data
train_rmse <- sqrt(mean((train_predictions - train_data$Target)^2))
# Print RMSE for comparison
cat("Holdout Data RMSE:", rmse, "\n")
cat("Training Data RMSE:", train_rmse, "\n")
In practice, it’s common to evaluate a model’s performance on both a holdout set (or validation set) and the training set. The RMSE on the training set gives you an idea of how well the model fits the data it was trained on, while the RMSE on the holdout set provides an estimate of how well the model generalizes to new, unseen data. Comparing these two values helps assess potential overfitting. Overfitting and underfitting are common challenges in machine learning, particularly when building predictive models.
Overfitting occurs when a model is too complex and captures noise or random fluctuations in the training data as if they were real patterns. A sign of overfitting is that performance on the training set is high, but performance on new data is poor. A common cause is too much model complexity and insufficient regularization. To address overfitting, use simpler models, apply regularization techniques (e.g., L1 or L2 regularization, to be covered later), or increase the amount of training data.
Underfitting occurs when a model is too simple to capture the underlying patterns in the training data. A sign of underfitting is poor performance on both the training set and new data. To address underfitting, use more complex models.
The above discussion involves a tradeoff between bias (due to underfitting) and variance (due to overfitting). A good model finds the right balance between bias and variance to generalize well to new data.
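The tradeoff can be illustrated with a small simulation that is separate from the example above. The sketch below assumes a quadratic true relationship and fits polynomials of degree 1 (too simple), 2 (matching the truth), and 15 (too flexible); all names and settings are illustrative.

```r
set.seed(1)
x <- runif(60, -2, 2)
y <- x^2 + rnorm(60, sd = 0.5)          # assumed true model: quadratic plus noise
dat <- data.frame(x = x, y = y)
train <- dat[1:30, ]
holdout <- dat[31:60, ]

# RMSE of a fitted model on a given data set
rmse_of <- function(fit, data) {
  sqrt(mean((predict(fit, newdata = data) - data$y)^2))
}

# Fit polynomials of increasing degree and record both RMSE values
degrees <- c(1, 2, 15)
results <- t(sapply(degrees, function(d) {
  fit <- lm(y ~ poly(x, d), data = train)
  c(degree = d,
    train_rmse = rmse_of(fit, train),
    holdout_rmse = rmse_of(fit, holdout))
}))
results
```

Only the training RMSE is guaranteed to fall as the degree grows, because each model contains the previous one. The holdout RMSE typically stops improving, or gets worse, for the degree-15 fit, which is the signature of overfitting.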
We will introduce cross-validation to evaluate model performance on multiple subsets of the data.
We can analyze learning curves to understand how model performance changes with the size of the training set.
We use regularization techniques to penalize overly complex models.
Example of creating learning curves:
# Uses the data frame 'simulated_data' created above (predictors and Target)
set.seed(123)
# Determine the total number of observations in the dataset
n <- nrow(simulated_data)
# Define a sequence of training set sizes, ranging from 50% to 80% of the total observations
training_set_sizes <- seq(round(0.5 * n), round(0.8 * n), by = 1)
# Determine the number of training set sizes
k <- length(training_set_sizes)
# Create empty vectors to store performance metrics for training and validation sets
train_performance <- c()
validation_performance <- c()
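Explanation of code: The code defines the training set sizes (every integer from 50% to 80% of the observations) and creates two empty vectors that will collect the training and validation RMSE for each size.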
# Loop through different training set sizes
for (i in training_set_sizes) {
# Subset the training data
indices = sample(1:n, size = i)
subset_train_data <- simulated_data[indices, ]
# Use the remaining data as the validation set
subset_validation_data <- simulated_data[-indices, ]
# Fit the model
model <- lm(Target ~ Predictor1 + Predictor2 + Predictor3, data = subset_train_data)
# Make predictions on the training set
train_predictions <- predict(model, newdata = subset_train_data)
train_rmse <- sqrt(mean((train_predictions - subset_train_data$Target)^2))
train_performance <- c(train_performance, train_rmse)
# Make predictions on the validation set
validation_predictions <- predict(model, newdata = subset_validation_data)
validation_rmse <- sqrt(mean((validation_predictions - subset_validation_data$Target)^2))
validation_performance <- c(validation_performance, validation_rmse)
}
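Explanation of code: For each training set size, the loop samples that many rows for training, uses the remaining rows as the validation set, fits the regression model, and appends the training and validation RMSE to the corresponding vectors.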
# Plot the learning curve
plot(training_set_sizes, train_performance, type = "l", col = "blue", ylim = c(0, max(c(train_performance, validation_performance))), xlab = "Training Set Size", ylab = "RMSE", main = "Learning Curve", lwd = 2)
lines(training_set_sizes, validation_performance, type = "l", col = "red", lwd = 2)
legend("bottomright", legend = c("Training Set", "Validation Set"), col = c("blue", "red"), lwd = 2, bty = "n")
Explanation of code:
plot: Creates the initial plot with the training set sizes on the x-axis and training set RMSE on the y-axis. lines: Adds a line for the validation set RMSE. legend: Adds a legend to the plot.
We select a best subset of predictors by assessing model performance using AIC. There are 3 ways: backward elimination, forward selection, and stepwise regression.
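# Load caret, which provides trainControl() and train()
library(caret)
# Define the training control with no resampling
train_control <- trainControl(method = "none")
# Train a multiple regression model with stepwise selection by AIC.
# Squared terms must be wrapped in I(); in a formula, a bare Predictor1^2 is the same as Predictor1.
model <- train(
Target ~ Predictor1+Predictor2+Predictor3+Predictor1:Predictor2+Predictor1:Predictor3+Predictor2:Predictor3+I(Predictor1^2)+I(Predictor2^2)+I(Predictor3^2),
data = train_data,
method = "glmStepAIC", # Stepwise selection by AIC (uses MASS::stepAIC)
trControl = train_control,
direction = "both", # Terms may be added or dropped at each step
preProcess = c("center", "scale") # Standardize predictors
)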
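Explanation of code: trainControl(method = "none") fits the model once on the training data without resampling. method = "glmStepAIC" fits a generalized linear model and applies stepwise selection by AIC; direction = "both" allows terms to be added or dropped at each step, which corresponds to stepwise regression. preProcess = c("center", "scale") standardizes the predictors before fitting.
Explanation of result:
As the result shows, the best model is the one with all variables.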
Lasso, which stands for Least Absolute Shrinkage and Selection Operator, is a regression analysis method used for variable selection and regularization. It is particularly useful when dealing with datasets that have a large number of predictors (features) compared to the number of observations. Lasso regression helps prevent overfitting and simplifies the model by shrinking the coefficients of some predictors toward zero, effectively performing variable selection.
The lasso method adds a penalty term to the traditional linear regression cost function. The cost function for lasso regression is:
\[Cost=RSS+\lambda\cdot \sum_{j=1}^{p} |\beta_j|\]
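where RSS is the residual sum of squares, \(\lambda \ge 0\) is the tuning parameter that controls the strength of the penalty, \(\beta_j\) is the coefficient of the \(j\)th predictor, and \(p\) is the number of predictors. With \(\lambda = 0\) the lasso reduces to ordinary least squares; as \(\lambda\) increases, more coefficients are shrunk exactly to zero.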
Lasso regression is particularly valuable in situations where there are many irrelevant or redundant predictors, as it tends to select a subset of the most important predictors. It’s commonly used in machine learning and statistics for feature selection and regularization. The glmnet package in R is a popular tool for fitting lasso regression models.
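The way the penalty sets coefficients to zero can be seen in a special case. The sketch below assumes orthonormal predictors, for which minimizing the cost function above has a closed form: each least-squares estimate is moved toward zero by \(\lambda/2\) and set to zero if it crosses zero (soft thresholding). The function name and the coefficient values are hypothetical, and glmnet scales its objective differently, so its lambda values are not directly comparable to the ones used here.

```r
# Soft thresholding: lasso solution for orthonormal predictors under
# Cost = RSS + lambda * sum(abs(beta))
soft_threshold <- function(b, lambda) {
  sign(b) * pmax(abs(b) - lambda / 2, 0)
}

# Hypothetical least-squares estimates: three real effects and two near zero
ols <- c(2.0, -3.0, 0.5, 0.05, -0.02)

# One column per lambda value
shrunk <- sapply(c(0, 0.2, 1.5), function(l) soft_threshold(ols, l))
colnames(shrunk) <- c("lambda=0", "lambda=0.2", "lambda=1.5")
shrunk
```

With lambda = 0 the estimates are unchanged; with lambda = 0.2 the two small coefficients become exactly zero; with lambda = 1.5 the coefficient 0.5 is also removed.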
Here’s an example:
# Define the training control
train_control <- trainControl(method = "cv", number = 5) # 5-fold cross-validation
# Define the grid of hyperparameters for tuning (alpha and lambda)
tune_grid <- expand.grid(alpha = 1,
lambda = seq(0.001, 1, length = 100))
# Train lasso regression model with interactions and square terms
# (squared terms are wrapped in I() so that the formula treats ^ as arithmetic)
lasso_model <- train(Target ~ Predictor1+Predictor2+Predictor3+Predictor1:Predictor2+Predictor1:Predictor3+Predictor2:Predictor3+I(Predictor1^2)+I(Predictor2^2)+I(Predictor3^2),
data = train_data,
method = "glmnet",
trControl = train_control,
tuneGrid = tune_grid)
# Display the results with the optimal lambda value used
coef(lasso_model$finalModel, s = lasso_model$bestTune$lambda)
Explanation of code:
trainControl: Sets up the training control for the train function, specifying 5-fold cross-validation (method = “cv” and number = 5).
tune_grid: Defines a grid of hyperparameters for tuning. It includes alpha = 1 (indicating Lasso regularization) and a range of lambda values (seq(0.001, 1, length = 100)).
train: Uses the train function to train a Lasso regression model on the specified dataset (train_data) with the formula given in the code (the three main effects, their pairwise interactions, and the squared terms), using the glmnet method. It uses the specified training control and the grid of hyperparameters for tuning.
coef(lasso_model$finalModel, s = lasso_model$bestTune$lambda): Extracts the coefficients for the selected lambda value from the final model. lasso_model$finalModel represents the final fitted model, and lasso_model$bestTune$lambda represents the lambda value selected during the tuning process.
This code performs Lasso regression with 5-fold cross-validation, tunes lambda over the grid (alpha is fixed at 1), and then extracts the coefficients for the selected lambda value. The s parameter in coef is set to the best-tuned lambda value obtained from lasso_model$bestTune$lambda.
The following gives the performance metrics of the final model on the training and holdout data:
rbind(Training = mlba::regressionSummary(predict(lasso_model, train_data), train_data$Target),
Holdout = mlba::regressionSummary(predict(lasso_model, holdout_data), holdout_data$Target))
Note: the mlba package can be installed by using the following code:
# Install "mlba" package which gives access to all datasets and
# some utility functions such as "regressionSummary" in book "Machine Learning for Business Analytics"
if (!require(mlba)) {
library(devtools)
install_github("gedeck/mlba/mlba", force=TRUE)
}
Logistic regression is a statistical method that can be used for both explanatory analysis and prediction, depending on the context and goals of the analysis.
The general logistic regression model for binary response (0/1) is
\[\log\left(\frac{p}{1-p}\right)=\beta_0+\beta_1\cdot x_1+\beta_2\cdot x_2+\cdots+\beta_k\cdot x_k\] or, equivalently,
\[p=\frac{e^{\beta_0+\beta_1\cdot x_1+\beta_2\cdot x_2+\cdots+\beta_k\cdot x_k}}{1+e^{\beta_0+\beta_1\cdot x_1+\beta_2\cdot x_2+\cdots+\beta_k\cdot x_k}}\] which can also be written as
\[p=\frac{1}{1+e^{-(\beta_0+\beta_1\cdot x_1+\beta_2\cdot x_2+\cdots+\beta_k\cdot x_k)}}\] where \(p=P(y=1|x_1, x_2, \ldots, x_k)\), \(\frac{p}{1-p}\) is the odds, and \(\log\left(\frac{p}{1-p}\right)\) is the log-odds.
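The equivalence of these expressions can be checked numerically. The sketch below uses hypothetical coefficients and predictor values chosen only for illustration; plogis() is the logistic function in base R.

```r
# Hypothetical coefficients and predictor values, for illustration only
beta <- c(-1.5, 0.8, -0.4)   # beta_0, beta_1, beta_2
x <- c(1, 2.0, 1.0)          # the leading 1 multiplies the intercept

# Linear predictor, which is the log-odds
eta <- sum(beta * x)

# The two expressions for p, plus the built-in logistic function
p_form1 <- exp(eta) / (1 + exp(eta))
p_form2 <- 1 / (1 + exp(-eta))
p_builtin <- plogis(eta)

# Converting p back to log-odds recovers the linear predictor
log_odds <- log(p_form1 / (1 - p_form1))

c(p_form1 = p_form1, p_form2 = p_form2, p_builtin = p_builtin, log_odds = log_odds)
```

All three probabilities agree, and the log-odds equals the linear predictor.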
Logistic regression can be used to model the relationship between a binary outcome variable (dependent variable) and one or more predictor variables (independent variables).
Coefficients in logistic regression represent the change in the log-odds of the outcome associated with a one-unit change in the predictor variable. This allows for the interpretation of the impact of predictors on the likelihood of the outcome.
Statistical tests on coefficients help assess the significance of predictors and understand whether they have a statistically significant impact on the outcome.
Logistic regression allows for the control of confounding variables, enabling the isolation of the relationship between specific predictors and the outcome.
We use the data frame “UniversalBank” from the package “mlba” to show an example. To install the package, do
if (!require(mlba)) {
library(devtools) # Install this package first
install_github("gedeck/mlba/mlba", force=TRUE)
}
For documentation about the data, type “?UniversalBank” in the console.
head(mlba::UniversalBank)
# Fit logistic regression model for explanatory analysis
explanatory_model <- glm(Personal.Loan ~ . - ID - ZIP.Code,
data = mlba::UniversalBank,
family = "binomial")
# Display model summary
summary(explanatory_model)
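Explanation of code: glm() with family = "binomial" fits a logistic regression model. The formula Personal.Loan ~ . - ID - ZIP.Code uses every column except ID and ZIP.Code as a predictor of Personal.Loan. summary() reports the coefficient estimates, their standard errors, z statistics, and p-values.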
Explanation of results:
All predictors show significance at either the 5% or even the 1% level, with the exception of Age, Experience, and Mortgage.
The coefficient (\(\beta\)) of a predictor represents the change in the log-odds of the outcome for a one-unit increase in this predictor, holding other predictors constant. For a positive coefficient, an increase in the predictor variable is associated with an increase in the log-odds of the outcome. For a negative coefficient, an increase in the predictor variable is associated with a decrease in the log-odds of the outcome.
The odds ratio (OR) is often used to interpret coefficients more intuitively. It is the exponentiation of the coefficient: \(e^\beta\). For \(OR > 1\), \(\beta>0\), and an increase in a predictor variable is associated with higher odds of the outcome. For \(OR < 1\), \(\beta<0\), and an increase in the predictor variable is associated with lower odds of the outcome. For \(OR = 1\), \(\beta=0\), and the predictor variable has no effect on the odds.
For example, the coefficient for Income is 0.05458. The interpretation is that for a one-unit increase ($1000) in Income, holding all other predictors constant, the odds of the outcome (\(Personal.Loan=1\)) increase by a factor of approximately \(e^{0.05458}\) or 1.056, which means the odds increase by 5.6%. The coefficient for CD.Account is 3.823. The interpretation is that the odds that a customer who has a CD account will accept the loan offer are \(e^{3.823}\) or 45.74 times that of a customer who does not have a CD account, holding all other predictors constant.
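The arithmetic behind these odds ratios can be reproduced directly from the two coefficients quoted above. The variable names in this sketch are illustrative.

```r
# Coefficients quoted in the text
beta_income <- 0.05458
beta_cd <- 3.823

# Odds ratios are the exponentiated coefficients
or_income <- exp(beta_income)
or_cd <- exp(beta_cd)
round(c(Income = or_income, CD.Account = or_cd), 3)

# Percentage change in the odds for a one-unit increase in Income
percent_change_income <- (or_income - 1) * 100
round(percent_change_income, 1)
```

For a fitted model, the same conversion can be applied to all coefficients at once with exp(coef(explanatory_model)).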
To check the goodness of fit (GOF) of the logistic model, we can plot the deviance residuals:
# Obtain the residual deviance from the summary
model_summary <- summary(explanatory_model)
model_deviance <- model_summary$deviance
# Compute the deviance residuals (named so as not to mask the residuals() function)
dev_residuals <- residuals(explanatory_model, type = "deviance")
plot(dev_residuals, pch = 16, main = "Deviance Residuals")
abline(h = 0, col = "red", lty = 2)
Interpretation of plot:
In the residual plot, look for patterns or trends. The red dashed line represents a residual of 0. A well-fitted model would have residuals evenly distributed around 0.
The model appears to fit adequately.
Logistic regression is commonly used for classification problems where the outcome variable is dichotomous (e.g., yes/no, 1/0).
Logistic regression provides probabilities that an observation belongs to a particular class. The logistic function transforms the linear combination of predictors into probabilities between 0 and 1.
For binary classification, predictions can be made by setting a decision threshold (e.g., 0.5). Observations with predicted probabilities above the threshold are classified as one category, and those below the threshold are classified as the other.
Model performance is assessed using metrics such as accuracy, precision, recall, F1 score, and ROC-AUC, depending on the specific goals and characteristics of the problem.
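Most of these metrics are simple functions of the four cells of a confusion matrix. The sketch below uses hypothetical counts (they do not come from the UniversalBank data) to show how accuracy, precision, recall, and the F1 score are computed.

```r
# Hypothetical confusion-matrix counts, for illustration only
tp <- 40    # true positives
fp <- 10    # false positives
fn <- 20    # false negatives
tn <- 130   # true negatives

accuracy <- (tp + tn) / (tp + fp + fn + tn)   # share of correct predictions
precision <- tp / (tp + fp)                   # correct among predicted positives
recall <- tp / (tp + fn)                      # found among actual positives
f1 <- 2 * precision * recall / (precision + recall)

round(c(accuracy = accuracy, precision = precision, recall = recall, f1 = f1), 4)
```

Here the accuracy is 0.85 even though a third of the actual positives are missed (recall 0.6667), which is why accuracy alone can be misleading when the positive class is rare.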
Some considerations:
Overfitting: Regularization techniques, such as L1 (LASSO) or L2 (Ridge) regularization, can be employed to prevent overfitting, especially when dealing with a large number of predictors.
Data Quality: Logistic regression is sensitive to outliers and multicollinearity, so data preprocessing steps such as outlier removal and addressing multicollinearity should be considered.
Cross-Validation: Cross-validation helps assess the generalization performance of the model and ensures its robustness on unseen data.
In summary, logistic regression is a versatile tool that can be used for both explanatory analysis, where the focus is on understanding relationships and impact, and prediction, where the goal is to make accurate predictions on new observations. The choice depends on the specific objectives and context of the analysis.
We use the data frame “UniversalBank” from the package “mlba” to show an example for classification.
We first create training and holdout data:
# Load the required packages
library(caret)
# Set seed for reproducibility
set.seed(123)
# Get the number of rows in the dataset
n <- nrow(mlba::UniversalBank)
# Generate random indices for the training set (80% of the rows; the remaining 20% form the holdout set)
train_indices <- sample(1:n, size = round(0.8 * n))
# Split the data into training and holdout sets
train_data <- mlba::UniversalBank[train_indices, ]
holdout_data <- mlba::UniversalBank[-train_indices, ]
The above code generates random indices corresponding to 80% of the total number of rows in the dataset. Using the random indices generated, this code selects the corresponding rows from the original dataset to create the training set (train_data). The holdout set (holdout_data) is then created by excluding the rows used for training.
Now, train the logistic model using the “glm” method:
# Define the train control
train_control <- trainControl(method = "cv", number = 10)
# Train logistic regression model using the 'train' function
classification_model <- train(
factor(Personal.Loan) ~ . - ID - ZIP.Code,
data = train_data,
method = "glm", # Specify logistic regression
trControl = train_control,
family = "binomial" # Specify binomial family for logistic regression
)
Explanation of code:
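This part of the code first sets up the configuration for model training using the trainControl function from the caret package: method = "cv" requests k-fold cross-validation, and number = 10 sets the number of folds to 10.
Then the code uses the train function to train a logistic regression model. The formula factor(Personal.Loan) ~ . - ID - ZIP.Code treats Personal.Loan as a categorical outcome and uses all other columns except ID and ZIP.Code as predictors; data = train_data supplies the training set; method = "glm" together with family = "binomial" specifies logistic regression; and trControl = train_control applies the cross-validation settings defined above.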
Make predictions on the holdout data and calculate the performance metric:
# Make predictions on the holdout set. For doc about predict() for train(), use ?predict.train
predicted_probabilities <- predict(classification_model, newdata = holdout_data, type = "prob")
# Set a decision threshold (e.g., 0.5) to classify observations
predicted_classes <- ifelse(predicted_probabilities[,2] > 0.5, 1, 0)
# Display the confusion matrix on the holdout set
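# (levels are set explicitly so both factors match even if one class is never predicted)
confusionMatrix(factor(predicted_classes, levels = c(0, 1)), factor(holdout_data$Personal.Loan, levels = c(0, 1)), positive = "1")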
The above code uses the predict function to generate predicted probabilities for each observation in the holdout set (holdout_data). The type = “prob” argument ensures that the raw predicted probabilities are obtained rather than class labels.
Here, a decision threshold of 0.5 is set. If the predicted probability for an observation is greater than 0.5, it is classified as 1; otherwise, it is classified as 0. This step converts the continuous predicted probabilities into binary class predictions. For a larger threshold, more false negatives will be generated. For a smaller threshold, more false positives will be generated.
The confusionMatrix function from the caret package is used to create a confusion matrix. It compares the predicted classes (predicted_classes) with the actual classes (holdout_data$Personal.Loan). The factor function is used to ensure that both observed and predicted labels are treated as factors in the confusion matrix. The option positive = “1” ensures that class “1” is treated as the positive class.
The confusion matrix provides information on the performance of the classification model, showing counts of true positives, true negatives, false positives, and false negatives.
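The effect of the decision threshold on the two kinds of error can be seen in a small simulation. The sketch below does not use the UniversalBank model; it generates hypothetical outcomes and predicted probabilities and counts the errors at three thresholds. All names are illustrative.

```r
set.seed(42)
# Hypothetical actual classes (about 30% positives)
actual <- rbinom(1000, 1, 0.3)
# Hypothetical predicted probabilities, higher on average for actual positives
prob <- plogis(rnorm(1000, mean = ifelse(actual == 1, 1, -1)))

# Count false positives and false negatives at a given threshold
count_errors <- function(threshold) {
  predicted <- as.integer(prob > threshold)
  c(threshold = threshold,
    false_positives = sum(predicted == 1 & actual == 0),
    false_negatives = sum(predicted == 0 & actual == 1))
}

error_table <- t(sapply(c(0.3, 0.5, 0.7), count_errors))
error_table
```

Raising the threshold can only shrink the set of observations classified as 1, so false positives never increase and false negatives never decrease. The right threshold depends on which error is more costly in the application.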
k-Nearest Neighbors (knn) is a simple and intuitive algorithm used for both classification and regression tasks in machine learning. The primary idea behind knn is to predict the class or value of a new data point by considering the majority class or average of the k-nearest data points in the feature space.
We explain how knn works for classification of a categorical outcome and prediction of a numeric outcome.
Given a new data point, identify the k-nearest neighbors to the new data point based on a distance metric (commonly Euclidean distance). These neighbors are the data points with the most similar feature values to the new data point. Assign the class label that is most frequent among the k-nearest neighbors to the new data point.
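The procedure can be written out in a few lines of base R. The following is a minimal sketch for a single new point, not the implementation used by caret; the function name knn_classify is hypothetical, and ties in the vote are broken in favor of the first class level.

```r
# Classify one new point by majority vote among its k nearest training points
knn_classify <- function(train_x, train_y, new_x, k = 5) {
  # Euclidean distance from new_x to every training row
  differences <- sweep(as.matrix(train_x), 2, as.numeric(new_x))
  distances <- sqrt(rowSums(differences^2))
  # Indices of the k closest training points
  nearest <- order(distances)[seq_len(k)]
  # Most frequent class among those neighbors
  votes <- table(train_y[nearest])
  names(votes)[which.max(votes)]
}

# Standardize the four iris measurements so that no feature dominates the distance
features <- scale(iris[, 1:4])

# Treat the first flower as the new point and the rest as training data
predicted_species <- knn_classify(features[-1, ], iris$Species[-1], features[1, ], k = 5)
predicted_species
```

The first flower in iris is a setosa, and its nearest neighbors are all setosa, so the vote returns that class.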
There are some key considerations when using the k-NN method:
The value of k represents the number of neighbors to consider. Choosing the right value of k is crucial and can impact the model’s performance.
Computational efficiency: As the dataset grows, the computational cost of finding neighbors increases.
Sensitivity to irrelevant features: knn considers all features equally, so irrelevant features may affect the model.
Sensitivity to the scale of features: It is recommended to scale features, especially if they have different units or scales.
knn is a non-parametric and instance-based learning algorithm, meaning it makes predictions based on the existing data points without explicitly learning a model. It’s a versatile algorithm used in various applications, but it may not perform well in high-dimensional or noisy datasets.
We use the iris data to build a classifier with Species as the outcome.
library(caret)
# Set seed for reproducibility
set.seed(123)
# Split the iris dataset into training and holdout sets
train_indices <- sample(1:nrow(iris), size = 0.7*nrow(iris))
train_data <- iris[train_indices, ]
holdout_data <- iris[-train_indices, ]
The above code randomly selects 70% of the rows for training (train_data) and the remaining 30% for testing (holdout_data) from the iris dataset.
Next, train a knn classification model using the ‘train’ function:
# Define the training control for 10-fold cross-validation
train_control <- trainControl(method = "cv", number = 10)
# Optional: tune over different values of k
tune_grid <- expand.grid(k = 1:12)
# Train knn classification model using the 'train' function
knn_model <- train(
Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
data = train_data,
method = "knn",
preProcess = c("center", "scale"), # Optional: standardize predictors
trControl = train_control,
tuneGrid = tune_grid
)
# Display the trained model
print(knn_model)
Finally, we produce the performance metrics:
# Make predictions on the testing set
predictions <- predict(knn_model, newdata = holdout_data)
# Display the confusion matrix on the testing set
confusionMatrix(predictions, holdout_data$Species)
# Make predictions on the training set
predictions <- predict(knn_model, newdata = train_data)
# Display the confusion matrix on the train set
confusionMatrix(predictions, train_data$Species)
# Display the best k used for the final model
knn_model$bestTune$k
Interpretation of results:
High Accuracy: A high accuracy indicates that the model is making correct predictions overall.
Sensitivity and Specificity: These metrics give insights into the model’s performance on positive and negative instances, respectively. A balance between sensitivity and specificity is crucial, depending on the application.
Precision: Precision measures how well the model avoids false positives. A high precision is important when the cost of false positives is high.
F1 Score: A good F1 score indicates a balance between precision and sensitivity. It’s particularly useful when there is an imbalance between the classes.
Given a new data point, identify the k-nearest neighbors to the new data point based on a distance metric. Assign the average value of the target variable of the k-nearest neighbors to the new data point. The distance metric is commonly Euclidean distance, but others such as Manhattan distance can be used.
We will not work through an example of knn for prediction of a numeric outcome.
Advantages of k-Nearest Neighbors (knn):
Simple and Intuitive: Knn is easy to understand and implement, making it accessible for beginners.
No Assumptions about Data: Knn makes minimal assumptions about the underlying data distribution, making it versatile across different types of datasets.
Handles Nonlinear Relationships: Knn can capture complex, nonlinear relationships in the data.
No Model Assumption: Knn is a non-parametric model, meaning it doesn’t assume a specific functional form for the data.
Disadvantages of k-Nearest Neighbors (knn):
Computationally Expensive: The prediction time in knn can be high, especially for large datasets, as it requires calculating distances for each prediction.
Sensitive to Noise and Outliers: Knn can be sensitive to noisy data and outliers, potentially leading to suboptimal predictions.
Curse of Dimensionality: Knn’s performance may degrade as the number of features (dimensions) increases. In high-dimensional spaces, the concept of proximity becomes less meaningful.
Choice of Distance Metric: The choice of distance metric can significantly impact the performance of knn. The appropriateness of the metric depends on the nature of the data.
Imbalanced Data: Knn can be biased toward the majority class in imbalanced datasets, leading to poor performance on minority classes.
Need for Feature Scaling: Knn is sensitive to the scale of features, and it is recommended to scale features before applying knn to avoid dominance by features with larger scales.
Memory Usage: Knn requires storing the entire training dataset in memory, which can be a limitation for large datasets.
Need for Optimal ‘k’: The choice of the parameter ‘k’ is critical, and selecting an inappropriate value can impact the model’s performance. Cross-validation is often used to find the optimal ‘k’.
In summary, while k-Nearest Neighbors is a simple and flexible algorithm, it comes with trade-offs, especially regarding computational efficiency, sensitivity to noise, and the impact of data characteristics on its performance. The choice to use knn should consider the specific characteristics of the dataset and the computational resources available.
Classification and Regression Trees (CART) are powerful predictive modeling techniques used in machine learning and data mining for both classification and regression tasks. CART is a non-parametric method, meaning it makes minimal assumptions about the form of the underlying data distribution.
Classification trees are used when the target variable is categorical, meaning it falls into a discrete set of classes or categories. For instance, predicting whether an email is spam or not spam, or classifying images of animals into different species. The tree structure consists of decision nodes and leaf nodes. At each decision node, the algorithm selects a feature and a threshold to split the data into two or more subsets. This splitting process continues recursively until each subset contains data from only one class or until a stopping criterion is met. The leaf nodes represent the final predicted class.
Regression trees are employed when the target variable is continuous, such as predicting house prices or forecasting stock prices. Similar to classification trees, regression trees partition the feature space into regions, but instead of predicting class labels at the leaf nodes, they predict numerical values. Each leaf node contains the mean (or median) of the target variable for the observations in that region.
The construction of a decision tree involves recursively partitioning the feature space based on the impurity or homogeneity of the data. Popular impurity measures for classification trees include Gini impurity and entropy (information gain), while for regression trees, the mean squared error (MSE) or variance reduction are commonly used. The goal is to create splits that result in homogeneous subsets with respect to the target variable.
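The impurity measures for classification can be computed directly. The sketch below defines Gini impurity and entropy for a vector of class labels and evaluates one candidate split on the iris data. The function names and the threshold 2.45 are chosen for illustration; this is not how rpart is implemented internally.

```r
# Gini impurity: 1 minus the sum of squared class proportions
gini <- function(labels) {
  p <- prop.table(table(labels))
  1 - sum(p^2)
}

# Entropy in bits: minus the sum of p * log2(p) over the classes present
entropy <- function(labels) {
  p <- prop.table(table(labels))
  p <- p[p > 0]
  -sum(p * log2(p))
}

# Weighted impurity of the two subsets created by splitting at a threshold
split_impurity <- function(feature, labels, threshold, impurity = gini) {
  left <- feature <= threshold
  n <- length(labels)
  sum(left) / n * impurity(labels[left]) +
    sum(!left) / n * impurity(labels[!left])
}

gini_root <- gini(iris$Species)         # three equal classes: 1 - 3 * (1/3)^2
entropy_root <- entropy(iris$Species)   # three equal classes: log2(3)
gini_after <- split_impurity(iris$Petal.Length, iris$Species, 2.45)
c(gini_root = gini_root, entropy_root = entropy_root, gini_after = gini_after)
```

Splitting at Petal.Length <= 2.45 isolates all 50 setosa flowers in a pure subset, so the weighted Gini impurity drops from 2/3 to 1/3. A tree-growing algorithm searches over features and thresholds for the split with the largest such reduction.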
Decision trees have a tendency to overfit the training data, capturing noise. Pruning is a technique used to address this issue by trimming the tree to improve its generalization performance on unseen data. Pruning methods include cost complexity pruning (also known as weakest link pruning) and reduced error pruning.
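Advantages of CART: trees are easy to interpret and visualize, handle both numeric and categorical predictors, require no feature scaling, and capture nonlinear relationships and interactions between predictors.
Limitations of CART: a single tree tends to overfit unless it is pruned, can change substantially with small changes in the training data, and is often less accurate than an ensemble of trees.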
In practice, ensemble methods like Random Forests and Gradient Boosting Machines are often used to improve the predictive performance of decision trees while mitigating their limitations.
We use the iris data to demonstrate how a classification tree can be constructed to classify species.
Create training and holdout data:
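The code for this step is not shown here. A minimal sketch, assuming the same 70/30 random split used in the knn example, is given below. The object names training_data and holdout_data match the code that follows, but the seed is an assumption, so the counts in the outputs shown later may differ from what this split produces.

```r
# Assumed seed; the seed used for the outputs in this section is not known
set.seed(123)

# Randomly select 70% of the rows of iris for training
train_indices <- sample(seq_len(nrow(iris)), size = round(0.7 * nrow(iris)))
training_data <- iris[train_indices, ]

# The remaining 30% form the holdout set
holdout_data <- iris[-train_indices, ]

c(training = nrow(training_data), holdout = nrow(holdout_data))
```

This gives 105 training rows and 45 holdout rows, which matches the totals in the confusion matrices below.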
Define the Training Control:
# Define the training control
train_control <- trainControl(
method = "cv",
number = 10,
savePredictions = "final",
classProbs = TRUE
)
Define the grid of hyperparameter \(cp\) (complexity parameter) to search over for rpart:
# Define the grid of hyperparameters to search over for rpart
grid <- expand.grid(
cp = seq(0.01, 0.1, by = 0.01)
)
Train the model:
# Load required libraries
library(caret)
library(rpart)
library(rpart.plot)
# Train the model with hyperparameter tuning
model <- train(
Species ~ .,
data = training_data,
method = "rpart",
trControl = train_control,
tuneGrid = grid
)
Note:
Plot the decision tree:
# Plot the decision tree
rpart.plot(model$finalModel, yesno = 2, type = 0, extra = 101)
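Explanation of parameters: model$finalModel is the rpart tree fitted with the selected cp value. yesno = 2 writes “yes” and “no” at every split. type = 0 draws a split label at each split and a node label at each leaf. extra = 101 shows the number of observations per class in each node together with the percentage of all observations that fall in that node.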
Explanation of results:
Make predictions on the holdout data:
# Make predictions on the holdout data
predictions <- predict(model, newdata = holdout_data)
# Evaluate model performance
confusionMatrix(predictions, holdout_data$Species)
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 14 0 0
## versicolor 0 18 1
## virginica 0 0 12
##
## Overall Statistics
##
## Accuracy : 0.9778
## 95% CI : (0.8823, 0.9994)
## No Information Rate : 0.4
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.9662
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 1.0000 0.9231
## Specificity 1.0000 0.9630 1.0000
## Pos Pred Value 1.0000 0.9474 1.0000
## Neg Pred Value 1.0000 1.0000 0.9697
## Prevalence 0.3111 0.4000 0.2889
## Detection Rate 0.3111 0.4000 0.2667
## Detection Prevalence 0.3111 0.4222 0.2667
## Balanced Accuracy 1.0000 0.9815 0.9615
Make predictions on the training data:
# Make predictions on the training data
predictions <- predict(model, newdata = training_data)
# Evaluate model performance
confusionMatrix(predictions, training_data$Species)
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 36 0 0
## versicolor 0 31 4
## virginica 0 1 33
##
## Overall Statistics
##
## Accuracy : 0.9524
## 95% CI : (0.8924, 0.9844)
## No Information Rate : 0.3524
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.9286
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 0.9688 0.8919
## Specificity 1.0000 0.9452 0.9853
## Pos Pred Value 1.0000 0.8857 0.9706
## Neg Pred Value 1.0000 0.9857 0.9437
## Prevalence 0.3429 0.3048 0.3524
## Detection Rate 0.3429 0.2952 0.3143
## Detection Prevalence 0.3429 0.3333 0.3238
## Balanced Accuracy 1.0000 0.9570 0.9386
The model does not perform better on the training set (accuracy 0.9524) than on the holdout set (accuracy 0.9778), so there is no sign of overfitting.
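Random Forest and Boosted Trees are both ensemble learning techniques used for classification and regression tasks in machine learning. In brief, a Random Forest grows many decision trees, each on a bootstrap sample of the training data and with only a random subset of the predictors considered at each split (the mtry parameter below), and combines them by majority vote for classification or by averaging for regression. The following example fits a Random Forest to the iris training data.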
library(caret)
# Define the training control
train_control <- trainControl(
method = "cv",
number = 10,
savePredictions = "final",
classProbs = TRUE
)
# Define the grid of hyperparameters for Random Forest
grid <- expand.grid(
mtry = c(2,3,4) # Number of variables randomly sampled as candidates at each split
)
# Train the model with hyperparameter tuning for Random Forest
rf_model <- train(
Species ~ .,
data = training_data,
method = "rf", # Specify Random Forest as the model method
trControl = train_control, # Use the defined training control
tuneGrid = grid # Specify the grid of hyperparameters
)
# Get the tuned Random Forest model
tuned_rf_model <- rf_model$finalModel
tuned_rf_model
##
## Call:
## randomForest(x = x, y = y, mtry = param$mtry)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 3
##
## OOB estimate of error rate: 5.71%
## Confusion matrix:
## setosa versicolor virginica class.error
## setosa 36 0 0 0.00000000
## versicolor 0 29 3 0.09375000
## virginica 0 3 34 0.08108108
Make predictions:
# Make predictions on the training data
training_predictions <- predict(tuned_rf_model, newdata = training_data)
# Make predictions on the holdout data
holdout_predictions <- predict(tuned_rf_model, newdata = holdout_data)
# Create confusion matrix for training data
training_confusion_matrix <- confusionMatrix(training_predictions, training_data$Species)
cat("Confusion Matrix for Training Data:\n\n")
## Confusion Matrix for Training Data:
print(training_confusion_matrix)
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 36 0 0
## versicolor 0 32 0
## virginica 0 0 37
##
## Overall Statistics
##
## Accuracy : 1
## 95% CI : (0.9655, 1)
## No Information Rate : 0.3524
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 1
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 1.0000 1.0000
## Specificity 1.0000 1.0000 1.0000
## Pos Pred Value 1.0000 1.0000 1.0000
## Neg Pred Value 1.0000 1.0000 1.0000
## Prevalence 0.3429 0.3048 0.3524
## Detection Rate 0.3429 0.3048 0.3524
## Detection Prevalence 0.3429 0.3048 0.3524
## Balanced Accuracy 1.0000 1.0000 1.0000
# Create confusion matrix for holdout data
holdout_confusion_matrix <- confusionMatrix(holdout_predictions, holdout_data$Species)
cat("Confusion Matrix for Holdout Data:\n\n")
## Confusion Matrix for Holdout Data:
print(holdout_confusion_matrix)
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 14 0 0
## versicolor 0 17 0
## virginica 0 1 13
##
## Overall Statistics
##
## Accuracy : 0.9778
## 95% CI : (0.8823, 0.9994)
## No Information Rate : 0.4
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.9664
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 0.9444 1.0000
## Specificity 1.0000 1.0000 0.9688
## Pos Pred Value 1.0000 1.0000 0.9286
## Neg Pred Value 1.0000 0.9643 1.0000
## Prevalence 0.3111 0.4000 0.2889
## Detection Rate 0.3111 0.3778 0.2889
## Detection Prevalence 0.3111 0.3778 0.3111
## Balanced Accuracy 1.0000 0.9722 0.9844
A Boosted Tree is a type of ensemble learning method that combines multiple weak learners (typically decision trees) sequentially to create a strong learner. Boosted Trees belong to the class of gradient boosting algorithms, which iteratively build a series of models, each one correcting the errors made by the previous models.
Here are the key characteristics and concepts related to Boosted Trees:
Gradient Boosting: Boosted Trees use gradient boosting, which applies gradient descent to iteratively minimize a loss function by adding new decision trees to the ensemble. A common loss for classification is \(loss = \sum_{i=1}^n [-\log_2(p_i)]\), where \(p_i\) is the predicted probability of the observed class of the \(i\)th record. For prediction of a numerical outcome, a common loss is \(loss = \frac{1}{n}\sum_{i=1}^n\frac{1}{2}(y_i-\hat{y}_i)^2\). Each new tree is fitted to the residual errors of the previous trees (a reference: https://sefiks.com/2018/10/04/a-step-by-step-gradient-boosting-decision-tree-example/).
Sequential Learning: Boosted Trees build trees sequentially, where each subsequent tree focuses on the instances that were misclassified or had high residual errors in the previous iterations. This allows the model to learn from its mistakes and improve over time.
Weighted Voting: The final prediction combines the outputs of all trees. In AdaBoost-style boosting, each tree is assigned a weight based on its performance, and the final prediction is obtained by weighted voting or averaging, so trees that reduce the loss more contribute more. In gradient boosting, the outputs of the trees are summed, each scaled by the learning rate.
Regularization: Boosted Trees typically include regularization techniques to prevent overfitting, such as shrinkage (learning rate), subsampling, and tree depth constraints.
Versatility: Boosted Trees can be used for both regression and classification tasks. They are highly versatile and are often considered one of the most powerful machine learning algorithms, capable of achieving high predictive accuracy.
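The sequential fitting to residuals described above can be sketched in a few lines of base R. This is only an illustration under simplifying assumptions, not how xgboost is implemented: the weak learner is a one-split stump on a single predictor, the loss is the squared error, the data are simulated, and the helper names fit_stump and predict_stump are hypothetical.

```r
# Weak learner: a one-split "stump" fitted to the current residuals r
fit_stump <- function(x, r) {
  xs <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2   # candidate split points
  best <- NULL
  for (s in cuts) {
    left <- mean(r[x <= s])
    right <- mean(r[x > s])
    sse <- sum((r - ifelse(x <= s, left, right))^2)
    if (is.null(best) || sse < best$sse) {
      best <- list(split = s, left = left, right = right, sse = sse)
    }
  }
  best
}

predict_stump <- function(stump, x) {
  ifelse(x <= stump$split, stump$left, stump$right)
}

# Simulated (hypothetical) regression data
set.seed(1)
x <- seq(0, 10, length.out = 100)
y <- sin(x) + rnorm(100, sd = 0.2)

eta <- 0.1                        # shrinkage (learning rate)
n_trees <- 100
pred <- rep(mean(y), length(y))   # start from a constant prediction
mse <- numeric(n_trees)

for (m in seq_len(n_trees)) {
  residual <- y - pred            # negative gradient of 1/2 * (y - pred)^2
  stump <- fit_stump(x, residual) # new tree fitted to the residuals
  pred <- pred + eta * predict_stump(stump, x)
  mse[m] <- mean((y - pred)^2)
}

mse[c(1, 10, 100)]                # training error after 1, 10, and 100 trees
```

The training error never increases from one iteration to the next, which is why shrinkage, subsampling, and depth limits are needed to keep the ensemble from overfitting.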
library(xgboost)
##
## Attaching package: 'xgboost'
## The following object is masked from 'package:dplyr':
##
## slice
# Train the model for Boosted Trees (XGBoost)
xgb_model <- train(
Species ~ .,
data = training_data,
method = "xgbTree",
verbosity=0
)
# Get the tuned Boosted Trees model (XGBoost)
tuned_xgb_model <- xgb_model$finalModel
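Make predictions on new data. The final xgboost model expects a numeric matrix without the response variable, and it returns the class probabilities as one flat vector with three consecutive values per observation: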
# Make predictions on new data (must be matrix without response variable)
predictions <- predict(tuned_xgb_model, newdata = as.matrix(holdout_data[, -5]), type = "prob")
# Example: printing the first few predictions
print(head(predictions))
## [1] 0.985935807 0.012500440 0.001563783 0.986802399 0.010428885 0.002768697
Next, we use the UniversalBank data from the mlba package to fit several models and draw ROC curves to compare their performance.
Create training and holdout data:
UniversalBank = mlba::UniversalBank
UniversalBank = subset(UniversalBank, select = -c(ID, ZIP.Code))
UniversalBank$Personal.Loan = factor(UniversalBank$Personal.Loan, levels = 0:1, labels = c("No", "Yes"))
set.seed(123) # for reproducibility
trainIndex <- sample(1:nrow(UniversalBank), 0.7*nrow(UniversalBank))
training_data <- UniversalBank[trainIndex, ] # 70% of data for training
holdout_data <- UniversalBank[-trainIndex, ] # 30% of data for holdout
Train 3 different models:
# 1. Train rpart model
rpart_model <- train(Personal.Loan ~ .,
data = training_data,
method = "rpart"
)
# 2. Train Random Forest model
rf_model <- train(Personal.Loan ~ .,
data = training_data,
method = "rf")
# 3. Train Boosted Tree model (XGBoost)
xgb_model <- train(Personal.Loan ~ .,
data = training_data,
method = "xgbTree",
verbosity=0)
Note: the parameter "scale_pos_weight" can be added to the train() function when training the boosted tree model. This parameter is used to address class imbalance in binary classification problems. With imbalanced classes, one class (typically the minority class, often denoted as the positive class) has significantly fewer samples than the other class (often denoted as the negative class), and the model can become biased towards the majority class. This can lead to poor performance, particularly in correctly predicting the minority class.
The scale_pos_weight parameter is used to give more weight to the samples of the positive class during training. It essentially scales the gradient for the positive class, effectively balancing the contribution of both classes during training. By increasing the weight of the positive class, the model pays more attention to it, which can help in improving the predictive performance, especially in cases of severe class imbalance.
The value of scale_pos_weight is typically set to the ratio of negative class examples to positive class examples. For example, if you have 120 negative class examples and 10 positive class examples, you might set scale_pos_weight to 12 (120/10).
Here’s a brief summary of what scale_pos_weight does:
Increases weight for positive class: It assigns a higher weight to the samples from the positive class during training.
Addresses class imbalance: It helps to mitigate the impact of imbalanced class distributions in binary classification problems.
Improves predictive performance: By giving more weight to the minority class, it encourages the model to learn better representations of the positive class, which can improve metrics for that class such as recall and F1-score.
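The ratio rule above can be computed directly from the outcome variable. The sketch below uses a hypothetical outcome vector with the same counts as the example in the text (120 negatives and 10 positives).

```r
# Hypothetical binary outcome: 120 negatives ("No") and 10 positives ("Yes")
y <- factor(c(rep("No", 120), rep("Yes", 10)), levels = c("No", "Yes"))

n_neg <- sum(y == "No")    # size of the majority (negative) class
n_pos <- sum(y == "Yes")   # size of the minority (positive) class

spw <- n_neg / n_pos       # candidate value for scale_pos_weight
spw
```

Following the note above, this value would then be supplied as scale_pos_weight when training the boosted tree model.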
Make predictions:
# Predict probabilities on holdout data
rpart_prob <- predict(rpart_model, holdout_data, type = "prob")[, "Yes"]
rf_prob <- predict(rf_model, holdout_data, type = "prob")[, "Yes"]
xgb_prob <- predict(xgb_model, holdout_data[, -8], type = "prob")[, "Yes"]
Create an ROC curve for each model:
# Load required libraries
library(caret)
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
library(ggplot2)
# Create ROC curves
rpart_roc <- roc(holdout_data$Personal.Loan, rpart_prob)
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
rf_roc <- roc(holdout_data$Personal.Loan, rf_prob)
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
xgb_roc <- roc(holdout_data$Personal.Loan, xgb_prob)
## Setting levels: control = No, case = Yes
## Setting direction: controls < cases
# Extract the true positive rate (sensitivity) and the false positive rate (1 - specificity) from the ROC curves
rpart_df <- data.frame(TPR = rpart_roc$sensitivities,
FPR = 1 - rpart_roc$specificities)
rf_df <- data.frame(TPR = rf_roc$sensitivities,
FPR = 1 - rf_roc$specificities)
xgb_df <- data.frame(TPR = xgb_roc$sensitivities,
FPR = 1 - xgb_roc$specificities)
# Plot ROC curves using ggplot2 with legend
ggplot() +
geom_line(data = rpart_df, aes(x = FPR, y = TPR, color = "rpart"), linetype = 2) +
geom_line(data = rf_df, aes(x = FPR, y = TPR, color = "Random Forest"), linetype = 2) +
geom_line(data = xgb_df, aes(x = FPR, y = TPR, color = "Boosted Tree"), linetype = 2) +
labs(title = "ROC Curves for rpart, Random Forest, and Boosted Tree",
x = "False Positive Rate (1 - Specificity)",
y = "True Positive Rate (Sensitivity)") +
scale_color_manual(values = c("rpart" = "blue", "Random Forest" = "red", "Boosted Tree" = "green")) +
theme_minimal() +
theme(legend.position = "bottom")
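To make clear what an ROC curve plots, the following base R sketch computes the true positive rate and false positive rate at every threshold and then the area under the curve (AUC) with the trapezoidal rule. The labels and probabilities are hypothetical and are not taken from the UniversalBank models.

```r
# Hypothetical holdout labels and predicted probabilities of "Yes"
actual <- c("No", "No", "Yes", "No", "Yes", "Yes", "No", "Yes")
prob   <- c(0.10, 0.35, 0.40, 0.45, 0.60, 0.80, 0.20, 0.90)

# Sweep the classification threshold from high to low
thresholds <- c(Inf, sort(unique(prob), decreasing = TRUE))
tpr <- sapply(thresholds, function(t) mean(prob[actual == "Yes"] >= t))
fpr <- sapply(thresholds, function(t) mean(prob[actual == "No"] >= t))

# Area under the ROC curve by the trapezoidal rule
auc <- sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)
auc
```

Here 15 of the 16 (positive, negative) pairs are ranked correctly, so the AUC is 15/16. For the fitted models, pROC provides auc() for roc objects, which gives a single number for comparing the three curves.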
Neural networks, also called artificial neural networks (ANN), are models for classification or prediction, like k-NN and decision trees. Images of many neural networks can be found via https://www.google.com/.
A neural network has an input layer, zero, one, or more hidden layers, and an output layer.
(Source: https://towardsdatascience.com/designing-your-neural-networks-a5e4617027ed)
A network without a hidden layer is called a perceptron. A network with one or more hidden layers is called a multilayer perceptron (MLP). The input layer is composed of input neurons (also called nodes). The inputs are the values of the predictors (called features) of a single observation (such as an individual, an image, or a document). Each hidden layer has a few neurons. The output layer can have one or more neurons. A neuron can be activated by the other neurons to which it is connected. Each directed line connecting two neurons is called a synapse. The strength of the connection between two neurons is quantified by a weight.
For the prediction of a single numerical variable, one output node is needed. For a classification problem, the number of output nodes equals the number of classes or the number of classes minus one. A tutorial on neural networks can be found here: https://ujjwalkarn.me/2016/08/09/quick-intro-neural-networks/
A very nice tutorial for detailed computation of outputs of a neural network is: https://stevenmiller888.github.io/mind-how-to-build-a-neural-network/.
In the following graph, a neural network with inputs and weights is shown. We explain how some numbers are obtained.
How to compute the output of a neural network?
Some computation details:
Both inputs are 1 (say for an observation in your training data), so their weighted contribution to
the first node in the hidden layer is \(1(0.8) + 1(0.2) = 1.0\)
the second node in the hidden layer is \(1(0.4) + 1(0.9) = 1.3\)
the third node in the hidden layer is \(1(0.3) + 1(0.5) = 0.8\).
After activation with the logistic function
\[f(x) = \frac{1}{1+e^{-x}}\]
the output from each of these 3 nodes is
\(f(1.0) = \frac{1}{1+e^{-1.0}} = 0.73\)
\(f(1.3) = \frac{1}{1+e^{-1.3}} = 0.79\)
\(f(0.8) = \frac{1}{1+e^{-0.8}} = 0.69\)
Finally, use the numbers 0.73, 0.79, and 0.69 as inputs to get the weighted contribution \(0.73(0.3) + 0.79(0.5) + 0.69(0.9)\) or 1.2 to the output node. After activation with the same logistic function, the output is 0.77.
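The computation above can be reproduced in R. This is a minimal sketch of the forward pass for this one example, without bias terms. Which weight belongs to which input is an assumption here, but it does not affect the result because both inputs equal 1.

```r
logistic <- function(x) 1 / (1 + exp(-x))

input <- c(1, 1)

# Rows = inputs, columns = hidden nodes (assignment of weights to inputs is assumed)
W_hidden <- matrix(c(0.8, 0.4, 0.3,
                     0.2, 0.9, 0.5), nrow = 2, byrow = TRUE)
w_output <- c(0.3, 0.5, 0.9)   # hidden nodes -> output node

hidden_net <- as.vector(input %*% W_hidden)   # weighted sums: 1.0, 1.3, 0.8
hidden_out <- logistic(hidden_net)            # activations of the hidden nodes
output_net <- sum(hidden_out * w_output)      # weighted sum at the output node
output <- logistic(output_net)                # final output

round(hidden_out, 2)
round(output, 2)
```

Because the code carries the unrounded hidden outputs forward, the weighted sum at the output node differs slightly in the third decimal from the hand calculation, but the final output still rounds to 0.77.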
A feed-forward neural network is an artificial neural network where the connections (called edges) between two consecutive layers do not form a cycle. It is a directed graph. In a dense neural network, each neuron (say \(i\)) in a layer is connected to each neuron (say \(j\)) in the next layer. Each edge is associated with a weight (denoted by \(w_{ij}\)) describing the contribution of \(i\) to \(j\). Technically, a dummy node is added to the input layer and to each hidden layer. All dummy nodes take the same value of 1. The textbook uses \(\theta_k\) to denote the contribution of the \(k\)th dummy node. These \(\theta_k\)’s are called biases.
To train a neural network, follow these steps:
Pick a neural network architecture. What is the number of input nodes (equal to the number of features)? What is the number of hidden layers (the most common is 1 or 2)? What is the number of nodes in each of the hidden layers? This is a hard issue. Some suggest that the number of nodes in each hidden layer be roughly equal to the mean of the nodes in the input and output layers. What is the number of output nodes? The number of output nodes either equals 1 (for regression) or equals the number of classes (for classification).
Dummify nominal features, score ordinal features using values between 0 and 1, and normalize numerical features to \([0, 1]\) or \([-1,1]\).
Choose initial weights: The weights are randomly chosen to be very close to 0, such as between \(-0.1\) and 0.1.
Choose an activation function. Activation functions are what make a neural network model nonlinear. Examples of activation functions include the identity function, logistic (or sigmoid) function, hyperbolic tangent function (or tanh), and ReLU (Rectified Linear Unit).
Choose an error or cost function for optimizing weights. The R package “neuralnet” provides two error functions through the argument “err.fct”. One is the “sum of squared error (sse)” function, defined as
\[E_{sse}=\sum_{i=1}^{n}\sum_{j=1}^{c}\frac{1}{2}(y_{ij}-\hat{y}_{ij})^2\] This error function can be used for both regression (with quantitative response) and classification (with nominal response).
For a classification problem, the error function can be chosen to be the “cross-entropy (ce)” function, defined as
\[E_{ce}=-\sum_{i=1}^{n}\sum_{j=1}^{c}y_{ij}\cdot log(\hat{y}_{ij}),\]
where \(n\) is the number of observations, \(c\) is the number of output nodes, \(y_{ij}\) represents the observed value of the \(i\)th observation for the \(j\)th output node, and \(\hat{y}_{ij}\) represents the corresponding predicted propensity, (usually) based on the softmax transformation. A softmax transformation maps a vector \(x = (x_1, \ldots, x_c)\) to the vector with elements \(\frac{e^{x_j}}{\sum_{k=1}^{c} e^{x_k}}\). For example, the vector \((0.4, 0.9, 1.5)\) can be transformed to
\[(\frac{e^{0.4}}{e^{0.4}+e^{0.9}+e^{1.5}}, \frac{e^{0.9}}{e^{0.4}+e^{0.9}+e^{1.5}}, \frac{e^{1.5}}{e^{0.4}+e^{0.9}+e^{1.5}})\]
or \((0.1769, 0.2917, 0.5314)\), the elements of which add up to 1.
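The softmax example and the two error functions can be checked numerically. In the sketch below, the observed response y (a single record whose true class is the third one) is a hypothetical value chosen only for illustration.

```r
softmax <- function(x) exp(x) / sum(exp(x))

p <- softmax(c(0.4, 0.9, 1.5))
round(p, 4)   # propensities; they add up to 1
sum(p)

# Hypothetical single observation whose true class is the third one
y <- c(0, 0, 1)

sse <- sum(0.5 * (y - p)^2)   # sum of squared errors for this record
ce <- -sum(y * log(p))        # cross-entropy for this record
c(sse = sse, ce = ce)
```

With a one-hot response, the cross-entropy reduces to the negative log of the propensity assigned to the true class.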
Use the gradient descent method and the back-propagation (BP) algorithm (popularized in 1986 by Rumelhart, Hinton, and Williams) to update weights. A recommended tutorial: http://home.agh.edu.pl/~vlsi/AI/backp_t_en/backprop.html
Apply stopping rules to avoid overfitting (large variance). By default, neuralnet() uses stepmax = 100000 (the maximum number of training steps) and threshold = 0.01 (the stopping threshold for the partial derivatives of the error function with respect to the weights). One can try different stepmax values and plot the errors based on the training and validation sets. The curve based on the training data should keep decreasing, whereas the curve based on the validation data should decrease and then increase. The best stepmax value corresponds to the turning point of the validation curve.
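The preprocessing step in the list above (dummify nominal features and normalize numerical features) can be sketched in base R as follows. The data frame and the helper rescale01 are hypothetical.

```r
# Hypothetical data: one numerical and one nominal feature
d <- data.frame(income = c(20, 50, 80, 110),
                region = factor(c("east", "west", "north", "east")))

# Min-max normalization to [0, 1]
rescale01 <- function(x) (x - min(x)) / (max(x) - min(x))
d$income01 <- rescale01(d$income)

# Dummy coding: "- 1" drops the intercept so that every level gets its own column
dummies <- model.matrix(~ region - 1, data = d)

cbind(d["income01"], dummies)
```

In practice, the minimum and maximum should be computed on the training data and then reused for the validation and holdout data.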
Once a neural network architecture is selected, a learning procedure, called back-propagation, is used to repeatedly adjust the weights of the connections in the network so as to minimize a measure (called an error, loss, or cost function) of the difference between the actual output vector of the net and the desired output vector. This article, https://www.nature.com/articles/323533a0, describes the back-propagation procedure.
For the first record in the training data, the model starts with a set of initial weights, usually chosen randomly by the software. The output produced by the output node (or nodes) is compared with the actual outcome value for the first record. The error is back-propagated to each hidden node, and all weights are then updated using the gradient descent method. This is demonstrated in the wonderful tutorial: http://home.agh.edu.pl/~vlsi/AI/backp_t_en/backprop.html. This series of videos may also help: https://www.youtube.com/watch?v=5u0jaA3qAGk
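One back-propagation update for a single record can be written out in base R. This is a simplified sketch: the network has 2 inputs, 3 hidden nodes, and 1 output node, it has no bias terms, and the record, starting weights, and learning rate are all hypothetical.

```r
logistic <- function(x) 1 / (1 + exp(-x))

# One hypothetical training record and its actual outcome
x <- c(1, 1)
y <- 1

# Hypothetical starting weights (no bias terms, for brevity)
W1 <- matrix(c(0.8, 0.4, 0.3,
               0.2, 0.9, 0.5), nrow = 2, byrow = TRUE)   # input -> hidden
w2 <- c(0.3, 0.5, 0.9)                                    # hidden -> output

# Forward pass: hidden outputs, network output, and squared error
forward <- function(W1, w2) {
  h <- logistic(as.vector(x %*% W1))
  o <- logistic(sum(h * w2))
  list(h = h, o = o, error = 0.5 * (y - o)^2)
}

before <- forward(W1, w2)

# Backward pass (chain rule), using f'(z) = f(z) * (1 - f(z)) for the logistic
delta_o <- (before$o - y) * before$o * (1 - before$o)   # dE/d(net input of output node)
grad_w2 <- delta_o * before$h                           # dE/d(w2)
delta_h <- delta_o * w2 * before$h * (1 - before$h)     # dE/d(net input of hidden nodes)
grad_W1 <- outer(x, delta_h)                            # dE/d(W1)

# Gradient descent update with a hypothetical learning rate
lr <- 0.5
W1_new <- W1 - lr * grad_W1
w2_new <- w2 - lr * grad_w2

after <- forward(W1_new, w2_new)
c(before = before$error, after = after$error)
```

The error for this record is smaller after the update. Note that neuralnet() by default uses resilient back-propagation (algorithm = "rprop+"), which adapts the step size for each weight rather than using a single fixed learning rate.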
R packages such as “nnet” and “neuralnet” can be used to fit neural networks. The book uses “neuralnet”. Please use the newest version of “neuralnet”. On the R console, issue the command:
devtools::install_github("bips-hb/neuralnet")
A neuron has inputs and an output. The output is obtained by applying an activation function to its net input (weighted sum of inputs). For a list of commonly used activation functions, refer to https://en.wikipedia.org/wiki/Activation_function. The activation functions R uses are
The “logistic” function:
\[logistic(x)=\frac{1}{1+e^{-x}}\]
The “tanh” function:
\[tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}\]
The “ReLU” function:
\[ReLU(x) = max(0, x)\]
A smoothed version of ReLU is the softplus function, which is defined as \[softplus(x)=log(1+e^x)\].
Here are the plots of those functions:
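The plots can be reproduced with a short base R sketch that evaluates the four functions on a grid and draws them in one figure. The grid range and plotting choices are arbitrary.

```r
logistic <- function(x) 1 / (1 + exp(-x))
relu <- function(x) pmax(0, x)
softplus <- function(x) log(1 + exp(x))

x <- seq(-4, 4, by = 0.1)
curves <- cbind(logistic = logistic(x),
                tanh = tanh(x),        # tanh() is built into base R
                ReLU = relu(x),
                softplus = softplus(x))

# One line per activation function
matplot(x, curves, type = "l", lty = 1:4, col = 1:4,
        xlab = "x", ylab = "f(x)", main = "Activation functions")
legend("topleft", legend = colnames(curves), lty = 1:4, col = 1:4)
```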
We will use the function neuralnet() from the "neuralnet" package to train a neural network.
Let’s use the following data to demonstrate how a neural network can be trained:
df = data.frame(Salt = c(0.9, 0.1, 0.4, 0.5, 0.5, 0.8),
Fat = c(0.2, 0.1, 0.2, 0.2, 0.4, 0.3),
Acceptance = c("like", "dislike", "dislike", "dislike", "like", "like")
)
df
## Salt Fat Acceptance
## 1 0.9 0.2 like
## 2 0.1 0.1 dislike
## 3 0.4 0.2 dislike
## 4 0.5 0.2 dislike
## 5 0.5 0.4 like
## 6 0.8 0.3 like
We first create a neural network using the default setting:
library(neuralnet)
##
## Attaching package: 'neuralnet'
## The following object is masked from 'package:dplyr':
##
## compute
## The following object is masked from 'package:ROCR':
##
## prediction
nn <- neuralnet(
Acceptance ~ Salt + Fat,
data = df,
hidden = c(3, 2), # Two hidden layers with 3 and 2 nodes
linear.output = FALSE # For classification
)
nn
## $call
## neuralnet(formula = Acceptance ~ Salt + Fat, data = df, hidden = c(3,
## 2), linear.output = FALSE)
##
## $response
## dislike like
## 1 FALSE TRUE
## 2 TRUE FALSE
## 3 TRUE FALSE
## 4 TRUE FALSE
## 5 FALSE TRUE
## 6 FALSE TRUE
##
## $covariate
## Salt Fat
## [1,] 0.9 0.2
## [2,] 0.1 0.1
## [3,] 0.4 0.2
## [4,] 0.5 0.2
## [5,] 0.5 0.4
## [6,] 0.8 0.3
##
## $model.list
## $model.list$response
## [1] "dislike" "like"
##
## $model.list$variables
## [1] "Salt" "Fat"
##
##
## $err.fct
## function (x, y)
## {
## 1/2 * (y - x)^2
## }
## <bytecode: 0x7ff3b38b5e18>
## <environment: 0x7ff3b38b4198>
## attr(,"type")
## [1] "sse"
##
## $act.fct
## function (x)
## {
## 1/(1 + exp(-x))
## }
## <bytecode: 0x7ff3b38c8390>
## <environment: 0x7ff3b38c56a8>
## attr(,"type")
## [1] "logistic"
##
## $output.act.fct
## function (x)
## {
## 1/(1 + exp(-x))
## }
## <bytecode: 0x7ff3b38c8390>
## <environment: 0x7ff3b38c5b40>
## attr(,"type")
## [1] "logistic"
##
## $linear.output
## [1] FALSE
##
## $data
## Salt Fat Acceptance
## 1 0.9 0.2 like
## 2 0.1 0.1 dislike
## 3 0.4 0.2 dislike
## 4 0.5 0.2 dislike
## 5 0.5 0.4 like
## 6 0.8 0.3 like
##
## $exclude
## NULL
##
## $net.result
## $net.result[[1]]
## [,1] [,2]
## [1,] 0.5021693 0.5015714
## [2,] 0.4988113 0.4986909
## [3,] 0.4993942 0.5003639
## [4,] 0.4999647 0.5006180
## [5,] 0.4969492 0.5016942
## [6,] 0.5001782 0.5018405
##
##
## $weights
## $weights[[1]]
## $weights[[1]][[1]]
## [,1] [,2] [,3]
## [1,] 1.3341348 0.9865734 -0.5129866
## [2,] 0.2012335 0.3957058 0.3502636
## [3,] -1.0250362 2.3439686 -1.6767225
##
## $weights[[1]][[2]]
## [,1] [,2]
## [1,] -1.9136941 -0.6076783
## [2,] -1.2177018 0.9069403
## [3,] 0.2256041 1.2386718
## [4,] -1.3977674 0.5626226
##
## $weights[[1]][[3]]
## [,1] [,2]
## [1,] -0.3717834 -0.5057464
## [2,] -1.6929532 0.7538025
## [3,] 0.5595842 0.6038245
##
##
##
## $generalized.weights
## $generalized.weights[[1]]
## [,1] [,2] [,3] [,4]
## [1,] 0.02142582 -0.05461309 0.009048639 0.02090921
## [2,] 0.02452375 -0.04279089 0.012680990 0.03759546
## [3,] 0.02297165 -0.05273932 0.010296640 0.02762620
## [4,] 0.02266438 -0.05321508 0.010034495 0.02621493
## [5,] 0.02177251 -0.06668553 0.007498394 0.01716901
## [6,] 0.02143919 -0.06160589 0.008053218 0.01777592
##
##
## $startweights
## $startweights[[1]]
## $startweights[[1]][[1]]
## [,1] [,2] [,3]
## [1,] 0.7657348 0.8885734 -1.081387
## [2,] 0.5696335 -0.1726942 1.250264
## [3,] -0.6566362 1.7755686 -1.108323
##
## $startweights[[1]][[2]]
## [,1] [,2]
## [1,] -1.2136941 -0.9876783
## [2,] -0.9177018 0.5269403
## [3,] 0.4706041 0.8586718
## [4,] -1.2102674 0.1826226
##
## $startweights[[1]][[3]]
## [,1] [,2]
## [1,] -0.7267834 -0.8857464
## [2,] -2.1029532 0.3738025
## [3,] 0.2595842 0.2238245
##
##
##
## $result.matrix
## [,1]
## error 1.497860045
## reached.threshold 0.004470739
## steps 10.000000000
## Intercept.to.1layhid1 1.334134753
## Salt.to.1layhid1 0.201233538
## Fat.to.1layhid1 -1.025036225
## Intercept.to.1layhid2 0.986573415
## Salt.to.1layhid2 0.395705850
## Fat.to.1layhid2 2.343968643
## Intercept.to.1layhid3 -0.512986628
## Salt.to.1layhid3 0.350263626
## Fat.to.1layhid3 -1.676722513
## Intercept.to.2layhid1 -1.913694139
## 1layhid1.to.2layhid1 -1.217701842
## 1layhid2.to.2layhid1 0.225604073
## 1layhid3.to.2layhid1 -1.397767448
## Intercept.to.2layhid2 -0.607678263
## 1layhid1.to.2layhid2 0.906940260
## 1layhid2.to.2layhid2 1.238671758
## 1layhid3.to.2layhid2 0.562622639
## Intercept.to.dislike -0.371783437
## 2layhid1.to.dislike -1.692953237
## 2layhid2.to.dislike 0.559584192
## Intercept.to.like -0.505746407
## 2layhid1.to.like 0.753802477
## 2layhid2.to.like 0.603824489
##
## attr(,"class")
## [1] "nn"
plot(nn) # Plot the neuralnet network
nn$net.result[[1]] # The probability of each sample to be a class (in alphabetical order)
## [,1] [,2]
## [1,] 0.5021693 0.5015714
## [2,] 0.4988113 0.4986909
## [3,] 0.4993942 0.5003639
## [4,] 0.4999647 0.5006180
## [5,] 0.4969492 0.5016942
## [6,] 0.5001782 0.5018405
# 1 = Dislike, 2 = Like, as printed in previous plot
Note that the values in each row need not sum to 1, because the logistic function is applied to each output node separately.
We next make predictions. The predictions usually are done based on holdout data. For demo purposes, predictions are just based on the training data:
predict(nn,newdata = df)
## [,1] [,2]
## [1,] 0.5021693 0.5015714
## [2,] 0.4988113 0.4986909
## [3,] 0.4993942 0.5003639
## [4,] 0.4999647 0.5006180
## [5,] 0.4969492 0.5016942
## [6,] 0.5001782 0.5018405
Next, we create a neural network using more parameters:
# We will fit a neural network with 2 input nodes, one hidden layer with 3 nodes, another hidden layer with 2 nodes, and 2 output nodes. There will be 23 weights to estimate. The software can choose initial weights.
nn <- neuralnet(
Acceptance ~ Salt + Fat,
data = df,
hidden = c(3, 2),
linear.output = FALSE, # For classification, set it to FALSE
err.fct = "sse", # The error function can be "sse" (default) or "ce"
rep = 5 # Number of training repetitions, each starting from new random weights (default = 1)
)
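The count of 23 weights quoted in the comment above can be checked with a small helper. The function name count_weights is hypothetical. It assumes a dense network with one bias (intercept) node feeding each hidden and output layer.

```r
# Number of weights in a dense feed-forward network, including biases.
# layer_sizes lists the nodes per layer, from input to output.
count_weights <- function(layer_sizes) {
  n_from <- head(layer_sizes, -1) + 1   # sending nodes plus one bias node
  n_to <- tail(layer_sizes, -1)         # receiving nodes in the next layer
  sum(n_from * n_to)
}

# 2 inputs, hidden layers with 3 and 2 nodes, 2 outputs
count_weights(c(2, 3, 2, 2))   # (2+1)*3 + (3+1)*2 + (2+1)*2
```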
We then plot the structure of the neural net.
plot(nn, "best") # Plot the repetition with the lowest error
Predictions for the training data, based on each repetition:
prediction(nn)
## Data Error: 0;
## $rep1
## Salt Fat dislike like
## 1 0.1 0.1 0.99070989 0.005253942
## 2 0.4 0.2 0.96994700 0.020050745
## 3 0.5 0.2 0.89127574 0.086148034
## 4 0.9 0.2 0.05340583 0.959565003
## 5 0.8 0.3 0.04889151 0.963454283
## 6 0.5 0.4 0.12867507 0.890273205
##
## $rep2
## Salt Fat dislike like
## 1 0.1 0.1 0.97195365 0.03241633
## 2 0.4 0.2 0.94216709 0.06510366
## 3 0.5 0.2 0.88063664 0.13082785
## 4 0.9 0.2 0.02834545 0.97042331
## 5 0.8 0.3 0.02138761 0.97745430
## 6 0.5 0.4 0.10581642 0.89353446
##
## $rep3
## Salt Fat dislike like
## 1 0.1 0.1 0.99827805 0.007614863
## 2 0.4 0.2 0.98124639 0.047544132
## 3 0.5 0.2 0.91207961 0.150419989
## 4 0.9 0.2 0.02197600 0.957691086
## 5 0.8 0.3 0.01973454 0.961080437
## 6 0.5 0.4 0.11285110 0.850325466
##
## $rep4
## Salt Fat dislike like
## 1 0.1 0.1 0.99565373 0.006197649
## 2 0.4 0.2 0.98107375 0.022473641
## 3 0.5 0.2 0.92255191 0.082906177
## 4 0.9 0.2 0.02245076 0.971365131
## 5 0.8 0.3 0.01767057 0.976870962
## 6 0.5 0.4 0.06943426 0.917621242
##
## $rep5
## Salt Fat dislike like
## 1 0.1 0.1 0.98446357 0.02372746
## 2 0.4 0.2 0.96423051 0.04934859
## 3 0.5 0.2 0.90688561 0.11452660
## 4 0.9 0.2 0.04813151 0.93613709
## 5 0.8 0.3 0.03641853 0.95001397
## 6 0.5 0.4 0.08840404 0.89103609
##
## $data
## Salt Fat dislike like
## 1 0.1 0.1 1 0
## 2 0.4 0.2 1 0
## 3 0.5 0.2 1 0
## 4 0.9 0.2 0 1
## 5 0.8 0.3 0 1
## 6 0.5 0.4 0 1
Prediction for the training data, based on the best repetition (the one with the lowest error):
predict(nn,newdata = df,
rep = which.min(nn$result.matrix[1,])) # Prediction based on the best repetition
## [,1] [,2]
## [1,] 0.02245076 0.971365131
## [2,] 0.99565373 0.006197649
## [3,] 0.98107375 0.022473641
## [4,] 0.92255191 0.082906177
## [5,] 0.06943426 0.917621242
## [6,] 0.01767057 0.976870962
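The propensities can be turned into predicted classes by picking, for each record, the output node with the largest value. The sketch below copies the matrix printed above and assumes the columns are ordered as dislike and like, as in the $response output shown earlier.

```r
# Propensities copied from the output above (one row per record of df)
prop <- matrix(c(0.02245076, 0.971365131,
                 0.99565373, 0.006197649,
                 0.98107375, 0.022473641,
                 0.92255191, 0.082906177,
                 0.06943426, 0.917621242,
                 0.01767057, 0.976870962),
               ncol = 2, byrow = TRUE,
               dimnames = list(NULL, c("dislike", "like")))

# Predicted class = column with the largest propensity in each row
predicted <- colnames(prop)[apply(prop, 1, which.max)]

# Actual classes from df
actual <- c("like", "dislike", "dislike", "dislike", "like", "like")

table(actual, predicted)
mean(predicted == actual)   # accuracy on the training data
```

All six training records are classified correctly, which is expected for such a small data set and says little about performance on new data.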
Next, let’s create a neural network for classifying species based on the iris data.
We first split the data into training and holdout sets.
library(neuralnet)
set.seed(123)
n = nrow(iris)
train.idx = sample(1:n, n*0.6)
train = iris[train.idx, ]
holdout=iris[-train.idx, ]
For classification, the response can be a single categorical variable or all dummy variables of the categorical variable. The propensities for each record may not add up to 1.
We demonstrate the method of using a single response variable (all other variables as features).
nn <- neuralnet(formula = Species ~ .,
data = train,
hidden = c(1), # One hidden layer with one node
act.fct = "logistic", # The activation function can also be defined by the user.
linear.output = FALSE,
rep = 10, # Number of training repetitions, each starting from new random weights.
# Use which.min(nn$result.matrix[1,]) to get the repetition with the smallest error.
# The results of repetition k are in nn$result.matrix[, k].
lifesign = "full" # How much to print during training: 'none', 'minimal' or 'full'.
)
## hidden: 1 thresh: 0.01 rep: 1/10 steps: 1000 min thresh: 0.226290033445551
## 2000 min thresh: 0.226290033445551
## 3000 min thresh: 0.170990940362761
## 4000 min thresh: 0.170990940362761
## 5000 min thresh: 0.122965259698456
## 6000 min thresh: 0.122965259698456
## 7000 min thresh: 0.0813866774964914
## 8000 min thresh: 0.0813866774964914
## 9000 min thresh: 0.0719146346478065
## 10000 min thresh: 0.0676218325090381
## 11000 min thresh: 0.0527273032348466
## 12000 min thresh: 0.0483990976230184
## 13000 min thresh: 0.0471145667873109
## 14000 min thresh: 0.0447387170517326
## 15000 min thresh: 0.0404548916370239
## 16000 min thresh: 0.0322734178518158
## 17000 min thresh: 0.024089283557897
## 18000 min thresh: 0.024089283557897
## 19000 min thresh: 0.024089283557897
## 20000 min thresh: 0.0194968564540944
## 21000 min thresh: 0.0194968564540944
## 22000 min thresh: 0.0178898180988451
## 23000 min thresh: 0.0178898180988451
## 24000 min thresh: 0.0178898180988451
## 25000 min thresh: 0.0155946222087751
## 26000 min thresh: 0.0154313369726457
## 27000 min thresh: 0.0129109217385888
## 28000 min thresh: 0.0122953862911799
## 29000 min thresh: 0.0121907195109253
## 30000 min thresh: 0.0119079499589907
## 31000 min thresh: 0.0119079499589907
## 32000 min thresh: 0.0107827623429392
## 32614 error: 7.58764 time: 4.19 secs
## hidden: 1 thresh: 0.01 rep: 2/10 steps: 1000 min thresh: 0.045625832802289
## 2000 min thresh: 0.045625832802289
## 3000 min thresh: 0.045625832802289
## 4000 min thresh: 0.045625832802289
## 5000 min thresh: 0.0437694107437924
## 6000 min thresh: 0.0332078729907998
## 7000 min thresh: 0.0256439389001248
## 8000 min thresh: 0.0200072834720572
## 9000 min thresh: 0.0173967870662027
## 10000 min thresh: 0.0159303305980142
## 11000 min thresh: 0.01331702678696
## 12000 min thresh: 0.0111912092809883
## 13000 min thresh: 0.0109065438280963
## 13275 error: 7.66328 time: 1.49 secs
## hidden: 1 thresh: 0.01 rep: 3/10 steps: 1000 min thresh: 0.0775145204849547
## 2000 min thresh: 0.0526637477863611
## 3000 min thresh: 0.0509576241570747
## 4000 min thresh: 0.0509576241570747
## 5000 min thresh: 0.0509576241570747
## 6000 min thresh: 0.0509576241570747
## 7000 min thresh: 0.0509576241570747
## 8000 min thresh: 0.046199217989409
## 9000 min thresh: 0.046199217989409
## 10000 min thresh: 0.046199217989409
## 11000 min thresh: 0.0446697368930781
## 12000 min thresh: 0.031455253897322
## 13000 min thresh: 0.0309238811124102
## 14000 min thresh: 0.0239402158555448
## 15000 min thresh: 0.0239402158555448
## 16000 min thresh: 0.0226763062363443
## 17000 min thresh: 0.0214722054900754
## 18000 min thresh: 0.0213680650649877
## 19000 min thresh: 0.0195458352881196
## 20000 min thresh: 0.0195458352881196
## 21000 min thresh: 0.0176881448909062
## 22000 min thresh: 0.0164666491313533
## 23000 min thresh: 0.0152283101103131
## 24000 min thresh: 0.0144557160117792
## 25000 min thresh: 0.0107045903571163
## 26000 min thresh: 0.0107045903571163
## 27000 min thresh: 0.0107045903571163
## 28000 min thresh: 0.0107045903571163
## 28026 error: 8.10857 time: 3.19 secs
## hidden: 1 thresh: 0.01 rep: 4/10 steps: 81 error: 14.74587 time: 0.01 secs
## hidden: 1 thresh: 0.01 rep: 5/10 steps: 1000 min thresh: 0.0682102641114169
## 2000 min thresh: 0.036774251587861
## 3000 min thresh: 0.0286527132298637
## 4000 min thresh: 0.016737168397448
## 5000 min thresh: 0.016392949920749
## 6000 min thresh: 0.016392949920749
## 7000 min thresh: 0.016392949920749
## 8000 min thresh: 0.016392949920749
## 9000 min thresh: 0.016392949920749
## 10000 min thresh: 0.016392949920749
## 11000 min thresh: 0.016392949920749
## 12000 min thresh: 0.016392949920749
## 13000 min thresh: 0.016392949920749
## 14000 min thresh: 0.016392949920749
## 15000 min thresh: 0.016392949920749
## 16000 min thresh: 0.016392949920749
## 17000 min thresh: 0.016392949920749
## 18000 min thresh: 0.016392949920749
## 19000 min thresh: 0.016392949920749
## 20000 min thresh: 0.016392949920749
## 21000 min thresh: 0.016392949920749
## 22000 min thresh: 0.0145447916328482
## 23000 min thresh: 0.0145447916328482
## 24000 min thresh: 0.0145447916328482
## 25000 min thresh: 0.0134240013645858
## 26000 min thresh: 0.0134240013645858
## 27000 min thresh: 0.0125820518392892
## 28000 min thresh: 0.0116032839965431
## 29000 min thresh: 0.0116032839965431
## 30000 min thresh: 0.01062405798604
## 30765 error: 7.59595 time: 3.34 secs
## hidden: 1 thresh: 0.01 rep: 6/10 steps: 1000 min thresh: 0.0551374865169311
## 2000 min thresh: 0.017299565885074
## 3000 min thresh: 0.013650117396441
## 4000 min thresh: 0.013650117396441
## 5000 min thresh: 0.013650117396441
## 6000 min thresh: 0.013650117396441
## 7000 min thresh: 0.013650117396441
## 8000 min thresh: 0.013650117396441
## 9000 min thresh: 0.013650117396441
## 10000 min thresh: 0.013650117396441
## 11000 min thresh: 0.0122373714170887
## 12000 min thresh: 0.0122373714170887
## 13000 min thresh: 0.0109153880994002
## 13864 error: 7.65549 time: 1.52 secs
## hidden: 1 thresh: 0.01 rep: 7/10 steps: 1000 min thresh: 0.0871226420436797
## 2000 min thresh: 0.0229203903188004
## 3000 min thresh: 0.0144555610812959
## 4000 min thresh: 0.0144555610812959
## 5000 min thresh: 0.0144555610812959
## 6000 min thresh: 0.0144555610812959
## 7000 min thresh: 0.0144555610812959
## 8000 min thresh: 0.0144555610812959
## 9000 min thresh: 0.0144555610812959
## 10000 min thresh: 0.0144555610812959
## 11000 min thresh: 0.0143110354013932
## 12000 min thresh: 0.0123015428123427
## 13000 min thresh: 0.0107918642966011
## 13648 error: 7.65809 time: 1.49 secs
## hidden: 1 thresh: 0.01 rep: 8/10 steps: 1000 min thresh: 0.0555194552434052
## 2000 min thresh: 0.0419509966141126
## 3000 min thresh: 0.0312270973445072
## 4000 min thresh: 0.0229761450015076
## 5000 min thresh: 0.0198829014140602
## 6000 min thresh: 0.019762882536451
## 7000 min thresh: 0.019762882536451
## 8000 min thresh: 0.019762882536451
## 9000 min thresh: 0.019762882536451
## 10000 min thresh: 0.019762882536451
## 11000 min thresh: 0.019762882536451
## 12000 min thresh: 0.019762882536451
## 13000 min thresh: 0.019762882536451
## 14000 min thresh: 0.019762882536451
## 15000 min thresh: 0.019762882536451
## 16000 min thresh: 0.0186629943661331
## 17000 min thresh: 0.0186629943661331
## 18000 min thresh: 0.0186629943661331
## 19000 min thresh: 0.0186629943661331
## 20000 min thresh: 0.0186629943661331
## 21000 min thresh: 0.0179518576993682
## 22000 min thresh: 0.0157406202346368
## 23000 min thresh: 0.0157406202346368
## 24000 min thresh: 0.0157406202346368
## 25000 min thresh: 0.0157406202346368
## 26000 min thresh: 0.0143947880418125
## 27000 min thresh: 0.0128495345943573
## 28000 min thresh: 0.0128495345943573
## 29000 min thresh: 0.0121264822871384
## 30000 min thresh: 0.011368057615844
## 31000 min thresh: 0.0112682318589401
## 32000 min thresh: 0.0112076447650147
## 33000 min thresh: 0.0112076447650147
## 34000 min thresh: 0.0107369651752861
## 34628 error: 7.57914 time: 4.11 secs
## hidden: 1 thresh: 0.01 rep: 9/10 steps: 1000 min thresh: 0.0616057803283966
## 2000 min thresh: 0.0616057803283966
## 3000 min thresh: 0.0557541426853788
## 4000 min thresh: 0.054036380176709
## 5000 min thresh: 0.0489066554890847
## 6000 min thresh: 0.0368786549480171
## 7000 min thresh: 0.029938533708948
## 8000 min thresh: 0.0240752954252181
## 9000 min thresh: 0.0216934502752998
## 10000 min thresh: 0.0166177755217189
## 11000 min thresh: 0.0153109125277142
## 12000 min thresh: 0.0130207702761637
## 13000 min thresh: 0.0123190887825177
## 14000 min thresh: 0.0101175708663564
## 14289 error: 7.65031 time: 1.55 secs
## hidden: 1 thresh: 0.01 rep: 10/10 steps: 19 error: 29.82224 time: 0 secs
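Plot the network. If rep=“best”, the repetition with the smallest error will be plotted. If rep is not specified, all repetitions will be plotted, each in a separate window. Note that a repetition is an independent training run that starts from new random weights; it is not an epoch in the usual sense.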
plot(nn, rep = "best") # For help, use ?plot.nn
We then make predictions based on the holdout data:
pred = predict(nn,
newdata = holdout,
rep = which.min(nn$result.matrix[1,]) # Prediction based on the best repetition,
# i.e., the one with the lowest error
)
head(pred) # Display propensity scores
## [,1] [,2] [,3]
## 1 0.9708753 0.4474900 1.984616e-42
## 2 0.9708697 0.4474899 1.984714e-42
## 3 0.9708712 0.4474899 1.984687e-42
## 5 0.9708753 0.4474900 1.984616e-42
## 11 0.9708764 0.4474900 1.984597e-42
## 15 0.9708773 0.4474900 1.984581e-42
# Get predicted labels
predicted.class = apply(pred, 1, which.max) %>% as.numeric()
predicted.class # Note: the columns follow the order of the factor levels of the response variable (alphabetical by default)
## [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 3 3 2 2 1 2
## [39] 3 3 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
# Convert predicted.class to factor in order to create confusion matrix
predicted = factor(predicted.class, levels = 1:3, labels = c("setosa", "versicolor", "virginica"))
# Convert actual column in test data to factor in order to create confusion matrix
actual = holdout$Species %>% factor(levels = c("setosa", "versicolor", "virginica"))
caret::confusionMatrix(predicted, actual)
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 20 1 0
## versicolor 0 19 0
## virginica 0 4 16
##
## Overall Statistics
##
## Accuracy : 0.9167
## 95% CI : (0.8161, 0.9724)
## No Information Rate : 0.4
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.8752
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 0.7917 1.0000
## Specificity 0.9750 1.0000 0.9091
## Pos Pred Value 0.9524 1.0000 0.8000
## Neg Pred Value 1.0000 0.8780 1.0000
## Prevalence 0.3333 0.4000 0.2667
## Detection Rate 0.3333 0.3167 0.2667
## Detection Prevalence 0.3500 0.3167 0.3333
## Balanced Accuracy 0.9875 0.8958 0.9545
Next, we use the book data “accidentsnn.csv” to create a neural network. Each row of the data table corresponds to a US automobile accident classified by its level of severity as no injury, injury, or fatality. The purpose is to develop a system for quickly classifying the severity of an accident, based on initial reports and associated data in the system (some of which rely on GPS-assisted reporting). Such a system could be used to assign emergency response team priorities. Here is a description of the variables:
ALCHL_I: Presence (1) or absence (2) of alcohol
PROFIL_I_R: Profile of the roadway (level = 1, other = 0)
SUR_COND: Surface condition of the road (dry = 1, wet = 2, snow/slush = 3, ice = 4, unknown = 9)
VEH_INVL: Number of vehicles involved
MAX_SEV_IR: Presence of injuries/fatalities (no injury = 0, injury = 1, fatality = 2)
Data pre-processing:
accidents.df <- read.csv("/Users/home/Documents/Zhang/Stat415.515.615/DMBA-R-datasets/accidentsnn.csv")
head(accidents.df)
## ALCHL_I PROFIL_I_R SUR_COND VEH_INVL MAX_SEV_IR
## 1 2 0 1 1 0
## 2 2 1 1 1 2
## 3 1 0 1 1 0
## 4 2 0 2 2 1
## 5 2 1 1 2 1
## 6 2 0 1 1 0
# Dummify categorical variables
accidents.df$SUR_COND_1 = (accidents.df$SUR_COND==1) * 1
accidents.df$SUR_COND_2 = (accidents.df$SUR_COND==2) * 1
accidents.df$SUR_COND_3 = (accidents.df$SUR_COND==3) * 1
accidents.df$SUR_COND_4 = (accidents.df$SUR_COND==4) * 1
accidents.df$SUR_COND_9 = (accidents.df$SUR_COND==9) * 1
accidents.df$ALCHL_I_1 = (accidents.df$ALCHL_I==1) * 1
accidents.df$ALCHL_I_2 = (accidents.df$ALCHL_I==2) * 1
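As an alternative to writing one line per category, base R's model.matrix() can generate the same indicator columns in one step. The following is a minimal sketch on a small made-up data frame (toy.df and its values are hypothetical and are not the accidents data); the column names are assigned by hand to mimic the naming used above.

```r
# Toy data (hypothetical values) mimicking the SUR_COND coding
toy.df <- data.frame(SUR_COND = c(1, 2, 3, 4, 9, 1))

# "~ 0 + factor(...)" drops the intercept, so every category gets its own column
dummies <- model.matrix(~ 0 + factor(SUR_COND), data = toy.df)
colnames(dummies) <- paste0("SUR_COND_", sort(unique(toy.df$SUR_COND)))

# Check against the manual approach for one category
manual <- (toy.df$SUR_COND == 1) * 1
all(dummies[, "SUR_COND_1"] == manual)

# Each row has exactly one 1, which is why one dummy is redundant in a model
rowSums(dummies)
```

Because the dummies in each row sum to 1, any one of them can be recovered from the others, so one column per categorical variable is normally left out of the model.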
Data partition:
set.seed(2)
n = nrow(accidents.df)
train.idx = sample(1:n, n*0.6)
train = accidents.df[train.idx, ]
holdout=accidents.df[-train.idx, ]
We will run a neural network with 2 hidden nodes. Use hidden= with a vector of integers specifying the number of hidden nodes in each layer. The dummy variables “SUR_COND_9” and “ALCHL_I_2” are left out of the model, since each is fully determined by the other dummy variables of the same categorical variable (perfect collinearity).
nn <- neuralnet(
formula = factor(MAX_SEV_IR) ~
ALCHL_I_1 + PROFIL_I_R + VEH_INVL + SUR_COND_1 + SUR_COND_2
+ SUR_COND_3 + SUR_COND_4,
data = train,
hidden = 2)
plot(nn, rep="best")
Note that the categorical response variable “MAX_SEV_IR” must be converted to a factor.
Next, we create two confusion matrices, one for the training data and one for the holdout data, and compare them. If the results are similar, there is no serious overfitting issue. If the results on the holdout data are much worse, there is a serious overfitting issue.
The confusion matrix for training data is
train.prediction <- predict(nn, train)
train.prediction.class <- apply(train.prediction,1,which.max)-1
caret::confusionMatrix(factor(train.prediction.class), factor(train$MAX_SEV_IR))
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1 2
## 0 334 0 35
## 1 0 164 34
## 2 1 7 24
##
## Overall Statistics
##
## Accuracy : 0.8715
## 95% CI : (0.842, 0.8972)
## No Information Rate : 0.5593
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.7675
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: 0 Class: 1 Class: 2
## Sensitivity 0.9970 0.9591 0.25806
## Specificity 0.8674 0.9206 0.98419
## Pos Pred Value 0.9051 0.8283 0.75000
## Neg Pred Value 0.9957 0.9825 0.87831
## Prevalence 0.5593 0.2855 0.15526
## Detection Rate 0.5576 0.2738 0.04007
## Detection Prevalence 0.6160 0.3306 0.05342
## Balanced Accuracy 0.9322 0.9398 0.62113
The confusion matrix for holdout data is
holdout.prediction <- predict(nn, holdout)
holdout.prediction.class <-apply(holdout.prediction,1,which.max)-1
caret::confusionMatrix(factor(holdout.prediction.class), factor(holdout$MAX_SEV_IR))
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1 2
## 0 216 0 20
## 1 0 123 23
## 2 0 5 13
##
## Overall Statistics
##
## Accuracy : 0.88
## 95% CI : (0.8441, 0.9102)
## No Information Rate : 0.54
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.7851
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: 0 Class: 1 Class: 2
## Sensitivity 1.0000 0.9609 0.2321
## Specificity 0.8913 0.9154 0.9855
## Pos Pred Value 0.9153 0.8425 0.7222
## Neg Pred Value 1.0000 0.9803 0.8874
## Prevalence 0.5400 0.3200 0.1400
## Detection Rate 0.5400 0.3075 0.0325
## Detection Prevalence 0.5900 0.3650 0.0450
## Balanced Accuracy 0.9457 0.9382 0.6088
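The comparison between the two confusion matrices can also be made numerically. The sketch below re-enters the counts from the two outputs above by hand (rows are predictions, columns are actual classes); the helper accuracy() and the object names are illustrative and not part of caret.

```r
# Confusion matrices copied from the outputs above (rows = predicted, columns = actual)
cm.train <- matrix(c(334,   0, 35,
                       0, 164, 34,
                       1,   7, 24), nrow = 3, byrow = TRUE)
cm.holdout <- matrix(c(216,   0, 20,
                         0, 123, 23,
                         0,   5, 13), nrow = 3, byrow = TRUE)

# Accuracy = correctly classified cases (the diagonal) / all cases
accuracy <- function(cm) sum(diag(cm)) / sum(cm)
acc.train   <- accuracy(cm.train)
acc.holdout <- accuracy(cm.holdout)
round(c(train = acc.train, holdout = acc.holdout), 4)

# Sensitivity for class 2 (fatality): correct class-2 predictions / actual class-2 cases
sens.fatal <- cm.holdout[3, 3] / sum(cm.holdout[, 3])
round(sens.fatal, 4)
```

The two accuracies are close (about 0.87 and 0.88), which suggests no serious overfitting. The sensitivity for fatalities is low in both sets (about 0.23 on the holdout data), so overall accuracy alone hides the weak detection of the most severe class.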
For numerical features (predictors) and quantitative responses, use z-scores (also called standardized scores) or rescale them to the range [0, 1] with the transformation \(x' = \frac{x-\min}{\max-\min}\). For each nominal categorical feature with \(m\) categories, use \(m-1\) dummy variables. For each ordinal categorical feature, map its categories to appropriate numerical values between 0 and 1. For example, if you have a feature “letter_grade”, the grades “A”, “B”, “C”, “D”, and “F” can be mapped to 1.00, 0.75, 0.50, 0.25, and 0, respectively, with the following R code:
# Assume your data frame D has an ordinal column called letter_grade.
# You can create a new numeric column as follows
# (the column must be referenced as D$letter_grade outside of mutate()):
D$letter_grade_numeric = dplyr::case_when(
D$letter_grade=="A"~1,
D$letter_grade=="B"~0.75,
D$letter_grade=="C"~0.5,
D$letter_grade=="D"~0.25,
.default=0 # "F" (and any other value) is mapped to 0
)
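The two rescaling options for numerical variables can be sketched in base R. The vectors x.train and x.new below are hypothetical, and the helper scale01() is only an illustration of the formula; the key point is that the minimum, maximum, mean, and standard deviation are taken from the training data and then reused for new data.

```r
x.train <- c(2, 4, 6, 10)   # hypothetical training values
x.new   <- c(3, 12)         # hypothetical holdout values

# Min-max scaling: (x - min) / (max - min), with min and max from the training data
mn <- min(x.train)
mx <- max(x.train)
scale01 <- function(x) (x - mn) / (mx - mn)
scale01(x.train)   # 0.00 0.25 0.50 1.00
scale01(x.new)     # 0.125 1.250: holdout values can fall outside [0, 1]

# z-scores: (x - mean) / sd, again using the training statistics
z.train <- (x.train - mean(x.train)) / sd(x.train)
z.new   <- (x.new - mean(x.train)) / sd(x.train)
```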
Now, we create neural networks for the iris data using Species as the response.
# use preProcess() from the caret package to standardize numerical variables.
# This function creates a recipe for normalizing. The same recipe is used for normalizing the
# numeric variables in holdout set as well.
set.seed(314)
n = nrow(iris)
train_idx <- sample(n, n * 0.8)
train <- iris[train_idx, ]
holdout <- iris[-train_idx, ]
# Normalize the training set to get a recipe
norm.values = preProcess(train, method = "range") # (x-min)/(max - min)
# Normalize both training and holdout data.
train.norm <- predict(norm.values, train)
holdout.norm <- predict(norm.values, holdout)
# Train a neural network.
nn <- neuralnet(
formula = Species ~ .,
data = train.norm,
hidden = c(2, 1, 5), # Three hidden layers with 2, 1, and 5 neurons, respectively
rep = 5, # The number of training repetitions (independent runs from random starting weights).
# The repetition with the least error is treated as the best one.
# Use code: which.min(nn$result.matrix[1,]) to get the best repetition.
# Its results are obtained by nn$result.matrix[, k], assuming the best repetition is k.
linear.output = FALSE
)
## Warning: Algorithm did not converge in 1 of 5 repetition(s) within the stepmax.
nn$result.matrix # Which repetition has the least error? The 1st (error = 1.001615).
## [,1] [,2] [,3]
## error 1.001615e+00 1.002009e+00 1.005104e+00
## reached.threshold 9.571835e-03 8.721278e-03 9.385633e-03
## steps 8.144100e+04 1.006900e+04 1.425000e+03
## Intercept.to.1layhid1 -2.989295e+00 -3.443804e+00 4.846014e-01
## Sepal.Length.to.1layhid1 -5.688573e-01 -3.190591e-01 2.277928e+00
## Sepal.Width.to.1layhid1 2.703880e-02 -2.128238e+00 -1.945185e+00
## Petal.Length.to.1layhid1 1.819080e+00 3.059828e+00 -2.743519e+00
## Petal.Width.to.1layhid1 3.603713e+00 3.334217e+00 -2.049547e+00
## Intercept.to.1layhid2 -3.129371e+00 8.021493e+00 2.328200e+00
## Sepal.Length.to.1layhid2 1.609274e-01 1.258237e+00 -3.192084e-01
## Sepal.Width.to.1layhid2 -1.765114e+00 -3.102604e+00 6.188162e+00
## Petal.Length.to.1layhid2 3.316089e+00 -5.469046e+00 -5.153139e+00
## Petal.Width.to.1layhid2 1.571098e+00 -5.155809e+00 -3.329929e+00
## Intercept.to.2layhid1 2.602505e+00 -1.584641e+00 1.923935e+00
## 1layhid1.to.2layhid1 -4.043594e+00 4.964020e+00 -9.495649e+00
## 1layhid2.to.2layhid1 -2.860173e+00 -1.958993e+00 -4.240781e+00
## Intercept.to.3layhid1 2.706306e+00 3.776164e+00 1.019637e+00
## 2layhid1.to.3layhid1 -1.650936e+01 -1.961648e+01 -3.182215e+00
## Intercept.to.3layhid2 -3.541384e+00 -5.066446e+01 -2.659779e+00
## 2layhid1.to.3layhid2 -3.994243e+00 9.491443e+02 4.562151e+00
## Intercept.to.3layhid3 -2.969237e+00 -3.884196e+00 -5.289218e+00
## 2layhid1.to.3layhid3 1.621150e+01 8.188528e+00 9.976585e+00
## Intercept.to.3layhid4 -6.718171e+00 -5.099584e+01 -7.570695e+00
## 2layhid1.to.3layhid4 1.096006e+01 9.554961e+02 1.420714e+02
## Intercept.to.3layhid5 -3.191871e+00 2.778127e+00 1.226080e+00
## 2layhid1.to.3layhid5 1.757565e+01 -8.289259e+00 -3.868800e+00
## Intercept.to.setosa -3.829634e+02 -1.499857e+00 -1.146594e-02
## 3layhid1.to.setosa -5.795981e+03 1.217913e+01 2.237267e+01
## 3layhid2.to.setosa -5.559741e+03 -8.821880e+01 -1.144389e+01
## 3layhid3.to.setosa -9.119178e+01 -3.325153e+01 -3.606054e+01
## 3layhid4.to.setosa 6.567385e+02 -8.849988e+01 -1.389433e+02
## 3layhid5.to.setosa -1.117092e+02 2.972396e+01 2.170046e+01
## Intercept.to.versicolor -5.915553e-01 7.281752e-01 -2.032978e+00
## 3layhid1.to.versicolor -1.457102e+03 -1.880853e+02 -5.875699e-01
## 3layhid2.to.versicolor -4.622684e+00 6.662514e+00 -1.745097e-01
## 3layhid3.to.versicolor 2.012961e+02 -2.030359e+02 -6.640448e+01
## 3layhid4.to.versicolor -2.797868e+02 7.878926e+00 2.801523e+01
## 3layhid5.to.versicolor 5.951978e+01 1.911285e+02 -3.555235e+00
## Intercept.to.virginica 3.321998e-01 -5.026696e+00 -1.225240e+00
## 3layhid1.to.virginica 1.822657e+03 -8.022399e+02 -2.356108e+01
## 3layhid2.to.virginica -3.082160e+00 1.083974e-01 1.765117e+00
## 3layhid3.to.virginica -7.349943e+01 1.884257e+02 6.664936e+01
## 3layhid4.to.virginica -6.790256e+03 3.584184e-01 1.474553e+00
## 3layhid5.to.virginica -4.562357e+01 -1.701733e+02 -4.694895e+01
## [,4]
## error 1.005804e+00
## reached.threshold 9.671142e-03
## steps 1.729000e+03
## Intercept.to.1layhid1 2.193635e+00
## Sepal.Length.to.1layhid1 1.854520e+00
## Sepal.Width.to.1layhid1 -1.383546e+00
## Petal.Length.to.1layhid1 -2.984769e+00
## Petal.Width.to.1layhid1 -2.454480e+00
## Intercept.to.1layhid2 1.846848e+00
## Sepal.Length.to.1layhid2 8.518098e-02
## Sepal.Width.to.1layhid2 7.120449e+00
## Petal.Length.to.1layhid2 -6.689385e+00
## Petal.Width.to.1layhid2 -3.155027e+00
## Intercept.to.2layhid1 -2.724231e+00
## 1layhid1.to.2layhid1 3.826678e+00
## 1layhid2.to.2layhid1 4.550591e+00
## Intercept.to.3layhid1 2.966600e+01
## 2layhid1.to.3layhid1 -3.935310e+01
## Intercept.to.3layhid2 3.032564e+00
## 2layhid1.to.3layhid2 -1.578746e+01
## Intercept.to.3layhid3 -2.005807e+00
## 2layhid1.to.3layhid3 7.275212e+00
## Intercept.to.3layhid4 -1.050371e+00
## 2layhid1.to.3layhid4 2.012009e+00
## Intercept.to.3layhid5 8.554065e-01
## 2layhid1.to.3layhid5 5.203757e-01
## Intercept.to.setosa -1.501191e+00
## 3layhid1.to.setosa -1.476478e+02
## 3layhid2.to.setosa -1.307848e+02
## 3layhid3.to.setosa 2.241865e+01
## 3layhid4.to.setosa 2.211282e+01
## 3layhid5.to.setosa 8.926680e-02
## Intercept.to.versicolor 1.047101e+00
## 3layhid1.to.versicolor 2.645109e+01
## 3layhid2.to.versicolor -9.304316e+01
## 3layhid3.to.versicolor -3.570186e+00
## 3layhid4.to.versicolor -2.496196e+00
## 3layhid5.to.versicolor -1.666424e+00
## Intercept.to.virginica -6.198779e-02
## 3layhid1.to.virginica -1.368415e+00
## 3layhid2.to.virginica 9.167592e+01
## 3layhid3.to.virginica -4.643261e+01
## 3layhid4.to.virginica 3.158593e-01
## 3layhid5.to.virginica -6.994342e-02
plot(nn, rep = "best") # For help, use ?plot.nn
# Identify the repetition with the least error
best_epoch <- which.min(nn$result.matrix[1,])
# Obtain unique levels of the target variable
target_levels <- levels(train.norm$Species)
# Predictions based on the best training repetition
pred_valid_prob <- predict(nn,
newdata = holdout.norm,
rep = best_epoch
)
# Assign column names based on target variable levels
colnames(pred_valid_prob) <- target_levels
# Convert probabilities to class labels
pred_valid <- apply(pred_valid_prob, 1, which.max)
pred_valid_lbls <- target_levels[pred_valid]
# Output the predicted class labels
pred_valid_lbls
## [1] "setosa" "setosa" "setosa" "setosa" "setosa"
## [6] "setosa" "setosa" "setosa" "setosa" "setosa"
## [11] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [16] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [21] "versicolor" "versicolor" "versicolor" "virginica" "virginica"
## [26] "virginica" "virginica" "virginica" "virginica" "virginica"
# Compute confusion matrix for the holdout data
confusionMatrix(factor(pred_valid_lbls), factor(holdout.norm$Species))
## Confusion Matrix and Statistics
##
## Reference
## Prediction setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 13 0
## virginica 0 0 7
##
## Overall Statistics
##
## Accuracy : 1
## 95% CI : (0.8843, 1)
## No Information Rate : 0.4333
## P-Value [Acc > NIR] : 1.273e-11
##
## Kappa : 1
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 1.0000 1.0000
## Specificity 1.0000 1.0000 1.0000
## Pos Pred Value 1.0000 1.0000 1.0000
## Neg Pred Value 1.0000 1.0000 1.0000
## Prevalence 0.3333 0.4333 0.2333
## Detection Rate 0.3333 0.4333 0.2333
## Detection Prevalence 0.3333 0.4333 0.2333
## Balanced Accuracy 1.0000 1.0000 1.0000
When numerical features are highly skewed, as is common in business applications, it is suggested to apply a log transformation before normalization.
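A minimal sketch of this order of operations, using a simulated right-skewed variable (the variable income and its distribution are assumptions made for illustration only):

```r
set.seed(1)
income <- rlnorm(200, meanlog = 10, sdlog = 1)  # simulated right-skewed variable (hypothetical)

# Step 1: log-transform (use log1p() instead if zeros can occur)
log.income <- log(income)

# Step 2: min-max normalize the logged values to [0, 1]
norm.income <- (log.income - min(log.income)) / (max(log.income) - min(log.income))

# A mean far above the median signals right skew; the ratio is close to 1 after logging
c(raw = mean(income) / median(income),
  logged = mean(log.income) / median(log.income))
range(norm.income)
```

In practice, the minimum and maximum used in step 2 would come from the training data, as in the preProcess() examples.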
When using neural networks for regression, the range of the output activation function should be consistent with the range of the quantitative response variable, or no output activation function is used (setting linear.output = TRUE). It is again suggested that variables be pre-processed (start with training data) with the function “preProcess” so that the range of each numeric variable is from 0 to 1, or from -1 to 1, or around 0.
When making predictions for a quantitative response variable, the predicted values should be converted back to the original scale.
The following uses a neural network to regress Sepal.Length on Petal.Length, Petal.Width, and Species in the iris data. Since neural networks only accept numeric predictors, the Species variable needs to be dummified, and only 2 of the 3 dummy variables are used.
iris1 = iris # Create a copy for iris data
# Create 3 dummy variables to replace the Species variables. Only two of these will be used.
iris1$Species_setosa = (iris1$Species=="setosa")*1
iris1$Species_versicolor = (iris1$Species=="versicolor")*1
iris1$Species_virginica = (iris1$Species=="virginica")*1
head(iris1) # Dummy variables are successfully created
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species Species_setosa
## 1 5.1 3.5 1.4 0.2 setosa 1
## 2 4.9 3.0 1.4 0.2 setosa 1
## 3 4.7 3.2 1.3 0.2 setosa 1
## 4 4.6 3.1 1.5 0.2 setosa 1
## 5 5.0 3.6 1.4 0.2 setosa 1
## 6 5.4 3.9 1.7 0.4 setosa 1
## Species_versicolor Species_virginica
## 1 0 0
## 2 0 0
## 3 0 0
## 4 0 0
## 5 0 0
## 6 0 0
# Partition data into 90% training and 10% holdout sets
set.seed(123)
n = nrow(iris1)
train_idx <- sample(nrow(iris1), n * 0.9)
train <- iris1[train_idx, ]
valid <- iris1[-train_idx, ]
# Normalization of numerical data using preProcess() from the caret package.
# This function creates a recipe used for normalizing ALL numerical data (including response)
# in both training and holdout sets
norm.values = caret::preProcess(train, method = "range") # Normalize to [0,1]
# Note the function predict() does the normalizing job.
train.norm <- predict(norm.values, train)
valid.norm <- predict(norm.values, valid)
# Fit a neural network
nn <- neuralnet(
formula = Sepal.Length ~ Petal.Length + Petal.Width + Species_setosa + Species_versicolor, # Use only 2 dummies
data = train.norm,
hidden = c(1, 2), # We use 2 hidden layers with layer 1 having 1 node and layer 2 having 2 nodes
linear.output = TRUE # For regression, set "linear.output" to TRUE; set it to FALSE for classification
)
plot(nn, rep = "best")
# Predictions
pred.train = predict(nn, newdata = train.norm) # based on normalized training data
head(pred.train)
## [,1]
## 14 0.1537110
## 50 0.1866937
## 118 0.9284185
## 43 0.1750122
## 150 0.5091412
## 148 0.5350192
pred.valid = predict(nn, newdata = valid.norm) # based on normalized holdout data
head(pred.valid)
## [,1]
## 1 0.1866937
## 18 0.1865793
## 28 0.1990731
## 33 0.1991943
## 48 0.1866937
## 55 0.5237451
# Convert predicted values to the original scale, since the response variable has been normalized.
Min = min(train$Sepal.Length) # Keep in mind that we normalized data
# based on the minimums and maximums of the train data
Max = max(train$Sepal.Length)
pred.ori.valid = as.numeric(pred.valid)*(Max-Min) + Min
head(pred.ori.valid)
## [1] 4.972097 4.971686 5.016663 5.017100 4.972097 6.185482
plot(x = valid$Sepal.Length, y = pred.ori.valid)
# Display accuracy measures
forecast::accuracy(pred.ori.valid, valid$Sepal.Length)
## ME RMSE MAE MPE MAPE
## Test set 0.1508507 0.2667496 0.2400849 2.207179 4.094318
To set a particular category of a nominal variable as the reference, we can use the relevel() function (it does not work for ordinal variables):
D= iris
D$Species=relevel(D$Species, ref = "versicolor")
# or, equivalently, list the levels explicitly with the reference category first
D$Species = factor(D$Species, levels = c("versicolor", "setosa", "virginica"))
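To check which category is the reference after releveling, one can inspect the factor levels and the dummy columns R would create. This is a small sketch using base R only; model.matrix() is used here just to display the implied dummy coding.

```r
D <- iris
levels(D$Species)   # "setosa" "versicolor" "virginica"

D$Species <- relevel(D$Species, ref = "versicolor")
levels(D$Species)   # "versicolor" "setosa" "virginica"

# The reference category is the one without its own dummy column
colnames(model.matrix(~ Species, data = D))
# "(Intercept)" "Speciessetosa" "Speciesvirginica"

# The data themselves are unchanged: each species still has 50 rows
table(D$Species)
```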
Commonly used deep learning methods include:
Convolutional Neural Networks (CNNs): Particularly effective for image recognition and classification tasks. CNNs utilize convolutional layers to extract features from input data.
Recurrent Neural Networks (RNNs): Suitable for sequential data such as time series, text, and speech. RNNs have loops that allow information to persist, enabling them to capture temporal dependencies.
Transformers: Introduced in the context of natural language processing (NLP), transformers have shown exceptional performance in tasks such as language translation, text generation, and sentiment analysis. They rely on self-attention mechanisms to capture contextual information from input sequences.
Neural networks have a high tolerance for noisy data and the ability to capture highly complicated relationships between the predictors and an outcome variable.
Neural networks do not have a built-in variable selection mechanism. This means that there is a need to use variables pre-selected by other models or methods, such as decision trees and PCA.
Neural networks may converge to weights at a local optimum and thus may not provide the best fit to the training data.
Neural networks are relatively heavy on computation time.
Neural networks are less interpretable than other supervised learning methods such as decision trees. Thus, the neural network is called a black-box model while the decision tree is called a white-box model.
A/B testing is a commonly used technique in practice to compare two versions (A and B) of a webpage, app, or marketing campaign to determine which one performs better in terms of predefined metrics such as click-through rates, conversion rates, or revenue. Here’s an example of how A/B testing is implemented in practice:
Let’s say you are a product manager at an e-commerce company, and you want to test whether changing the color of the “Buy Now” button on your website will increase the conversion rate. Currently, the button is green, and you want to test whether changing it to red will have a positive impact.
Define Hypothesis: First, you need to define your hypothesis. In this case, your null hypothesis (\(H_0\)) could be that changing the button color has no effect on the conversion rate, while your alternative hypothesis (\(H_1\)) could be that changing the button color to red will increase the conversion rate.
Setup Experiment: Divide your website visitors randomly into two groups: Group A, which sees the green button (the control group), and Group B, which sees the red button (the treatment group). Ensure that the groups are similar in terms of demographics and other relevant factors.
Collect Data: Monitor the behavior of both groups over a set period, tracking metrics such as the number of clicks on the “Buy Now” button and the number of completed purchases (conversions). Ensure that you collect enough data to make statistically significant conclusions.
Analyze Results: Once you have collected sufficient data, analyze the results to see if there is a significant difference in conversion rates between the two groups. You can use statistical methods such as hypothesis testing (e.g., z-test) to determine whether the difference is statistically significant.
Draw Conclusions: Based on the analysis, determine whether to reject or fail to reject the null hypothesis. If the conversion rate for the red button is significantly higher than that for the green button, you can conclude that changing the button color has a positive impact on the conversion rate.
Implement Changes: If the results of the A/B test are conclusive and positive, you can implement the change (i.e., permanently change the button color to red) for all website visitors.
Monitor Performance: Continuously monitor the performance of the red button to ensure that the improvement in conversion rate is sustained over time. You may also want to conduct further A/B tests to optimize other elements of the website.
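The analysis step above can be sketched as a one-sided two-proportion z-test. The counts below are made up purely for illustration (they are not real experiment results), and the object names are hypothetical. The test is computed by hand and then checked against prop.test() from base R's stats package.

```r
# Hypothetical counts: these numbers are made up for illustration only
conversions <- c(A = 200, B = 250)     # completed purchases (A = green, B = red)
visitors    <- c(A = 5000, B = 5000)   # visitors assigned to each group

# Two-proportion z-test by hand; H0: rate_B = rate_A, H1: rate_B > rate_A
p.hat   <- conversions / visitors
p.pool  <- sum(conversions) / sum(visitors)               # pooled rate under H0
se      <- sqrt(p.pool * (1 - p.pool) * sum(1 / visitors))
z       <- unname((p.hat["B"] - p.hat["A"]) / se)
p.value <- pnorm(z, lower.tail = FALSE)

# Same test with prop.test(); without continuity correction, X-squared equals z^2
test <- prop.test(x = conversions[c("B", "A")], n = visitors[c("B", "A")],
                  alternative = "greater", correct = FALSE)
c(z = z, p.value = p.value, prop.test.p = test$p.value)
```

With these made-up counts, z is about 2.41 and the one-sided p-value is below 0.05, so the null hypothesis of no effect would be rejected at the 5% level.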
In summary:
A/B testing is a straightforward experimental method used to compare two or more variations of a marketing campaign or website to determine which one performs better in terms of predefined metrics such as click-through rates, conversion rates, or revenue.
In A/B testing, users are randomly assigned to different versions (A and B) of the campaign or website, and their behavior is measured to assess the impact of the variations.
A/B testing is typically used to evaluate the overall effectiveness of a marketing campaign or website design by comparing the performance of different versions in terms of the target metric.
Uplift modeling, also known as persuasion modeling or incremental modeling, is a predictive modeling technique used to identify the individuals or segments within a target audience who are most likely to respond positively to a marketing intervention or treatment.
Uplift modeling aims to predict the incremental impact of a marketing campaign or treatment on individual customers’ behavior, such as whether they will make a purchase, subscribe to a service, or respond to a promotional offer.
Unlike A/B testing, which evaluates the average treatment effect across the entire population, uplift modeling focuses on identifying the subset of individuals who are most influenced by the treatment, allowing marketers to target their efforts more effectively by avoiding targeting those who would have acted positively without the intervention.
We focus on uplift modeling.
Here’s an outline of the steps to calculate uplift scores:
Train a Model: Train a machine learning model on historical data, where you have both treatment and control groups, as well as the outcomes of interest (e.g., response, conversion).
Predict Probabilities: Use the trained model to predict the probability of the desired outcome for both the treatment and control groups separately. You will obtain two sets of predicted probabilities: one for the treatment group and one for the control group.
Calculate Uplift Scores: For each observation, subtract the predicted probability under control from the predicted probability under treatment. This yields the uplift score for each observation, indicating the difference in the probability of the desired outcome between receiving and not receiving the treatment. For example, if P(purchase|20, M, 60000, Treat) = 0.67 and P(purchase|20, M, 60000, Control) = 0.554, the uplift score for this individual is 0.67 - 0.554 = 0.116.
We use simulated data to demonstrate the steps:
# Load required libraries
library(caret)
library(dplyr) # provides arrange(), used below
# Generate synthetic data
set.seed(123)
n <- 1000
treatment <- sample(0:1, n, replace = TRUE) # Treatment indicator (0 = control, 1 = treatment)
baseline_features <- matrix(runif(n * 5), ncol = 5) # Baseline features
# Binary response variable
probability = apply(cbind(treatment, baseline_features), 1,
function(x) 1/(1+exp(-(-1+5*x[1]-2*x[2]+3*x[3]-4*x[4]+5*x[5]-6*x[6]))))
response <- rbinom(n, 1, probability)
# Combine data into a data frame
data <- data.frame(treatment = treatment, baseline_features, response = response)
Train logistic model:
uplift_model <- glm(response~., family="binomial", data=data)
# All the following code calculates the uplift scores
data$treatment <- 1 # set the treatment column to 1
# Predict probabilities for treatment group
pred1 <- predict(uplift_model, newdata = data, type = "response")
data$treatment <- 0 # set the treatment column to 0
# Predict probabilities for control group
pred0 <- predict(uplift_model, newdata = data, type = "response")
# Calculate uplift scores
upliftResult <- data.frame(pred1,pred0, uplift = pred1 - pred0)
res1 = arrange(upliftResult, desc(uplift)) # sort by uplift, largest first
res1
## pred1 pred0 uplift
## 419 0.917997575 8.104588e-02 0.836951695
## 173 0.917525770 8.058153e-02 0.836944242
## 231 0.920202800 8.328249e-02 0.836920311
## 183 0.920636745 8.373591e-02 0.836900830
## 819 0.920648407 8.374816e-02 0.836900244
## 369 0.916027489 7.913852e-02 0.836888965
## 403 0.915909046 7.902645e-02 0.836882592
## 445 0.921072888 8.419620e-02 0.836876686
## 15 0.915734904 7.886221e-02 0.836872698
## 24 0.921186651 8.431702e-02 0.836869627
## 343 0.915515521 7.865617e-02 0.836859353
## 931 0.922029242 8.522186e-02 0.836807385
## 168 0.923065319 8.635910e-02 0.836706224
## 904 0.923141447 8.644375e-02 0.836697694
## 786 0.923803515 8.718646e-02 0.836617060
## 846 0.923874027 8.726624e-02 0.836607782
## 308 0.924071672 8.749061e-02 0.836581064
## 760 0.924094114 8.751615e-02 0.836577963
## 567 0.912146568 7.561067e-02 0.836535897
## 829 0.924539238 8.802562e-02 0.836513622
## 636 0.925199330 8.879121e-02 0.836408116
## 955 0.911011138 7.463195e-02 0.836379185
## 768 0.910242634 7.398243e-02 0.836260207
## 192 0.910209403 7.395457e-02 0.836254832
## 85 0.910188066 7.393670e-02 0.836251370
## 989 0.909766083 7.358476e-02 0.836181320
## 424 0.909764692 7.358361e-02 0.836181084
## 105 0.909578868 7.342959e-02 0.836149275
## 353 0.908727968 7.273172e-02 0.835996249
## 373 0.908702631 7.271112e-02 0.835991509
## 767 0.927390047 9.142201e-02 0.835968038
## 122 0.928272730 9.252290e-02 0.835749830
## 154 0.928287659 9.254173e-02 0.835745929
## 890 0.928370224 9.264599e-02 0.835724230
## 59 0.928621487 9.296463e-02 0.835656860
## 660 0.906555798 7.100331e-02 0.835552484
## 415 0.906279196 7.078852e-02 0.835490674
## 286 0.929469083 9.405454e-02 0.835414547
## 306 0.929553624 9.416454e-02 0.835389085
## 920 0.930074095 9.484702e-02 0.835227071
## 678 0.904463080 6.940676e-02 0.835056320
## 458 0.904038516 6.909070e-02 0.834947812
## 687 0.903716279 6.885254e-02 0.834863738
## 938 0.903624735 6.878515e-02 0.834839585
## 877 0.902997266 6.832640e-02 0.834670868
## 82 0.902402192 6.789637e-02 0.834505823
## 323 0.902252206 6.778875e-02 0.834463460
## 525 0.902137899 6.770693e-02 0.834430969
## 880 0.901933020 6.756073e-02 0.834372293
## 249 0.932480434 9.812480e-02 0.834355630
## 187 0.901743714 6.742614e-02 0.834317575
## 889 0.901675693 6.737790e-02 0.834297797
## 594 0.932721826 9.846519e-02 0.834256637
## 866 0.933073731 9.896534e-02 0.834108392
## 469 0.901006829 6.690678e-02 0.834100046
## 832 0.900455262 6.652270e-02 0.833932562
## 26 0.933824724 1.000486e-01 0.833776141
## 260 0.899485873 6.585713e-02 0.833628741
## 916 0.898243690 6.502146e-02 0.833222227
## 274 0.898217963 6.500436e-02 0.833213607
## 579 0.935111928 1.019572e-01 0.833154700
## 876 0.935139390 1.019987e-01 0.833140706
## 529 0.935297199 1.022375e-01 0.833059686
## 106 0.897613567 6.460475e-02 0.833008820
## 843 0.935492925 1.025352e-01 0.832957752
## 6 0.935601054 1.027003e-01 0.832900748
## 81 0.935661802 1.027933e-01 0.832868506
## 552 0.896838884 6.409891e-02 0.832739977
## 956 0.935978513 1.032806e-01 0.832697869
## 792 0.896398242 6.381432e-02 0.832583923
## 745 0.936504755 1.041000e-01 0.832404784
## 145 0.936612231 1.042688e-01 0.832343440
## 950 0.895622971 6.331903e-02 0.832303940
## 263 0.937037808 1.049423e-01 0.832095508
## 823 0.894569050 6.265659e-02 0.831912460
## 276 0.894485494 6.260460e-02 0.831880897
## 418 0.894374973 6.253594e-02 0.831839030
## 449 0.894115504 6.237529e-02 0.831740215
## 917 0.937850497 1.062512e-01 0.831599329
## 641 0.893143364 6.177983e-02 0.831363533
## 454 0.938593643 1.074749e-01 0.831118757
## 429 0.892426542 6.134718e-02 0.831079361
## 648 0.939103122 1.083291e-01 0.830774024
## 465 0.891517423 6.080613e-02 0.830711293
## 451 0.939487578 1.089821e-01 0.830505470
## 204 0.939740214 1.094152e-01 0.830324987
## 773 0.940229230 1.102628e-01 0.829966452
## 411 0.940358176 1.104883e-01 0.829869868
## 591 0.940428005 1.106108e-01 0.829817205
## 946 0.888927891 5.931031e-02 0.829617584
## 152 0.888228833 5.891759e-02 0.829311238
## 374 0.887553454 5.854252e-02 0.829010938
## 117 0.941944529 1.133350e-01 0.828609537
## 36 0.886164187 5.778405e-02 0.828380133
## 806 0.886140035 5.777102e-02 0.828369014
## 398 0.942351922 1.140883e-01 0.828263648
## 357 0.942514986 1.143924e-01 0.828122572
## 616 0.942655412 1.146555e-01 0.827999864
## 613 0.942685587 1.147122e-01 0.827973350
## 590 0.885129990 5.723058e-02 0.827899409
## 568 0.942912822 1.151408e-01 0.827771983
## 695 0.943000369 1.153068e-01 0.827693602
## 966 0.884030082 5.665208e-02 0.827378002
## 742 0.883876541 5.657214e-02 0.827304400
## 221 0.943639392 1.165316e-01 0.827107798
## 427 0.883450006 5.635110e-02 0.827098901
## 320 0.943867633 1.169750e-01 0.826892643
## 974 0.943875990 1.169913e-01 0.826884706
## 761 0.943970525 1.171759e-01 0.826794616
## 605 0.943985509 1.172052e-01 0.826780287
## 759 0.944033470 1.172991e-01 0.826734329
## 572 0.882449755 5.583865e-02 0.826611101
## 338 0.944483386 1.181871e-01 0.826296287
## 423 0.945148083 1.195222e-01 0.825625846
## 462 0.879918693 5.457770e-02 0.825340990
## 293 0.945593399 1.204326e-01 0.825160757
## 620 0.878834987 5.405293e-02 0.824782060
## 267 0.878699905 5.398813e-02 0.824711773
## 723 0.946224312 1.217450e-01 0.824479337
## 800 0.878240720 5.376888e-02 0.824471839
## 981 0.946372069 1.220562e-01 0.824315865
## 466 0.876894517 5.313495e-02 0.823759562
## 416 0.876540247 5.297029e-02 0.823569959
## 721 0.876220770 5.282255e-02 0.823398217
## 195 0.947225974 1.238845e-01 0.823341465
## 735 0.875739075 5.260115e-02 0.823137921
## 198 0.875680841 5.257450e-02 0.823106344
## 169 0.947813051 1.251716e-01 0.822641420
## 780 0.874588002 5.207857e-02 0.822509433
## 774 0.948099493 1.258088e-01 0.822290694
## 730 0.873820319 5.173503e-02 0.822085290
## 70 0.948702563 1.271704e-01 0.821532135
## 784 0.872717536 5.124836e-02 0.821469179
## 625 0.872554157 5.117693e-02 0.821377228
## 266 0.949102287 1.280883e-01 0.821013969
## 712 0.871728679 5.081866e-02 0.820910017
## 997 0.870793009 5.041779e-02 0.820375224
## 459 0.870120412 5.013298e-02 0.819987430
## 129 0.950420358 1.312054e-01 0.819214953
## 380 0.868605728 4.950167e-02 0.819104054
## 948 0.950750539 1.320087e-01 0.818741789
## 292 0.867819932 4.917954e-02 0.818640394
## 600 0.950922605 1.324311e-01 0.818491523
## 601 0.866034181 4.846074e-02 0.817573444
## 486 0.865525570 4.825931e-02 0.817266261
## 894 0.952054153 1.352732e-01 0.816780930
## 193 0.952098234 1.353863e-01 0.816711961
## 387 0.952324751 1.359700e-01 0.816354724
## 575 0.952350633 1.360370e-01 0.816313604
## 290 0.952491526 1.364029e-01 0.816088658
## 390 0.863562579 4.749520e-02 0.816067375
## 410 0.863056355 4.730151e-02 0.815754843
## 7 0.953178115 1.382126e-01 0.814965529
## 291 0.861578030 4.674354e-02 0.814834488
## 436 0.953401499 1.388112e-01 0.814590293
## 272 0.953435139 1.389018e-01 0.814533360
## 830 0.953612081 1.393800e-01 0.814232052
## 110 0.953688108 1.395865e-01 0.814101631
## 840 0.953753552 1.397647e-01 0.813988900
## 973 0.953783866 1.398473e-01 0.813936538
## 27 0.953789207 1.398619e-01 0.813927303
## 557 0.858453442 4.560053e-02 0.812852912
## 681 0.954778409 1.426121e-01 0.812166309
## 588 0.857060668 4.510629e-02 0.811954376
## 206 0.856528727 4.491993e-02 0.811608801
## 755 0.955116301 1.435751e-01 0.811541187
## 208 0.856380828 4.486834e-02 0.811512485
## 119 0.856108617 4.477366e-02 0.811334952
## 577 0.855876527 4.469321e-02 0.811183318
## 947 0.955449666 1.445374e-01 0.810912291
## 711 0.855397198 4.452782e-02 0.810869378
## 633 0.955566293 1.448769e-01 0.810689380
## 673 0.955628285 1.450580e-01 0.810570280
## 5 0.955656766 1.451414e-01 0.810515415
## 799 0.955740012 1.453855e-01 0.810354538
## 732 0.956417169 1.474006e-01 0.809016578
## 150 0.956522373 1.477184e-01 0.808803947
## 977 0.956546579 1.477917e-01 0.808754839
## 62 0.851317464 4.316112e-02 0.808156347
## 801 0.851151728 4.310710e-02 0.808044629
## 160 0.956940077 1.489933e-01 0.807946774
## 98 0.956974132 1.490982e-01 0.807875969
## 452 0.850659850 4.294745e-02 0.807712398
## 397 0.850189507 4.279573e-02 0.807393780
## 395 0.849943855 4.271684e-02 0.807227013
## 108 0.849603417 4.260792e-02 0.806995492
## 669 0.849459507 4.256202e-02 0.806897483
## 327 0.957443010 1.505563e-01 0.806886719
## 10 0.957765651 1.515755e-01 0.806190175
## 422 0.848001292 4.210158e-02 0.805899717
## 757 0.847525659 4.195320e-02 0.805572460
## 219 0.847213741 4.185637e-02 0.805357369
## 912 0.846924404 4.176689e-02 0.805157514
## 500 0.846863256 4.174802e-02 0.805115237
## 739 0.846320998 4.158131e-02 0.804739691
## 503 0.958437729 1.537412e-01 0.804696578
## 630 0.846026032 4.149109e-02 0.804534941
## 556 0.845665782 4.138135e-02 0.804284430
## 225 0.845628940 4.137016e-02 0.804258783
## 559 0.958703030 1.546123e-01 0.804090709
## 18 0.958840112 1.550661e-01 0.803773965
## 236 0.958869495 1.551638e-01 0.803705743
## 640 0.959211882 1.563098e-01 0.802902096
## 520 0.959453601 1.571286e-01 0.802324990
## 999 0.842757177 4.051287e-02 0.802244303
## 80 0.959488943 1.572490e-01 0.802239924
## 1 0.959499904 1.572864e-01 0.802213509
## 527 0.842426384 4.041604e-02 0.802010348
## 460 0.841102083 4.003220e-02 0.801069884
## 47 0.840736722 3.992737e-02 0.800809350
## 901 0.960156121 1.595554e-01 0.800600690
## 97 0.840278486 3.979654e-02 0.800481942
## 485 0.960247925 1.598778e-01 0.800370079
## 107 0.960321515 1.601372e-01 0.800184327
## 775 0.960685154 1.614306e-01 0.799254577
## 257 0.838354908 3.925507e-02 0.799099838
## 78 0.838234460 3.922157e-02 0.799012886
## 294 0.960836234 1.619738e-01 0.798862426
## 52 0.960899910 1.622038e-01 0.798696100
## 345 0.836966043 3.887169e-02 0.798094354
## 474 0.961168798 1.631820e-01 0.797986840
## 862 0.836776358 3.881981e-02 0.797956547
## 379 0.836118793 3.864086e-02 0.797477935
## 439 0.961654685 1.649783e-01 0.796676367
## 809 0.961656426 1.649848e-01 0.796671604
## 587 0.961744261 1.653136e-01 0.796430651
## 698 0.834601175 3.823303e-02 0.796368146
## 72 0.833948350 3.805978e-02 0.795888566
## 595 0.833918522 3.805190e-02 0.795866623
## 523 0.961985184 1.662219e-01 0.795763283
## 666 0.962021954 1.663614e-01 0.795660588
## 256 0.833251071 3.787617e-02 0.795374899
## 873 0.831769202 3.749078e-02 0.794278420
## 884 0.962741881 1.691377e-01 0.793604230
## 962 0.962785089 1.693071e-01 0.793477998
## 239 0.829923154 3.701966e-02 0.792903498
## 126 0.829905591 3.701522e-02 0.792890370
## 596 0.829308934 3.686506e-02 0.792443872
## 798 0.963351109 1.715571e-01 0.791794033
## 77 0.827571734 3.643352e-02 0.791138213
## 914 0.963629428 1.726845e-01 0.790944928
## 770 0.826167046 3.609061e-02 0.790076435
## 713 0.963926071 1.739019e-01 0.790024219
## 942 0.825900772 3.602621e-02 0.789874566
## 394 0.964021260 1.742960e-01 0.789725291
## 834 0.964239835 1.752074e-01 0.789032390
## 589 0.824550824 3.570256e-02 0.788848262
## 538 0.964302960 1.754724e-01 0.788830576
## 44 0.823859199 3.553859e-02 0.788320612
## 430 0.964701272 1.771619e-01 0.787539328
## 875 0.964763144 1.774272e-01 0.787335955
## 896 0.822140996 3.513651e-02 0.787004488
## 496 0.822010277 3.510622e-02 0.786904055
## 39 0.964950570 1.782353e-01 0.786715222
## 49 0.821477219 3.498316e-02 0.786494059
## 148 0.821098898 3.489625e-02 0.786202651
## 909 0.820551920 3.477121e-02 0.785780711
## 555 0.818039296 3.420608e-02 0.783833217
## 717 0.817265577 3.403506e-02 0.783230521
## 341 0.966005852 1.829204e-01 0.783085419
## 330 0.816460265 3.385852e-02 0.782601746
## 56 0.966155147 1.836024e-01 0.782552788
## 960 0.815000639 3.354230e-02 0.781458339
## 719 0.813222955 3.316358e-02 0.780059375
## 964 0.812627682 3.303830e-02 0.779589379
## 176 0.967015333 1.876283e-01 0.779387042
## 545 0.812344270 3.297893e-02 0.779365344
## 983 0.967048907 1.877889e-01 0.779260047
## 510 0.811615234 3.282697e-02 0.778788259
## 57 0.967257996 1.887948e-01 0.778463182
## 728 0.967284817 1.889246e-01 0.778360213
## 319 0.967731744 1.911128e-01 0.776618961
## 142 0.967767410 1.912895e-01 0.776477907
## 194 0.808000912 3.209002e-02 0.775910896
## 270 0.968137114 1.931400e-01 0.774997117
## 726 0.968401419 1.944841e-01 0.773917274
## 903 0.804905491 3.147972e-02 0.773425773
## 339 0.968734558 1.962042e-01 0.772530383
## 570 0.803749230 3.125650e-02 0.772492735
## 63 0.802869436 3.108833e-02 0.771781104
## 585 0.801873476 3.089970e-02 0.770973778
## 724 0.801763630 3.087900e-02 0.770884625
## 302 0.969443795 1.999652e-01 0.769478613
## 763 0.969460093 2.000532e-01 0.769406855
## 132 0.799455070 3.044915e-02 0.769005915
## 569 0.969588526 2.007498e-01 0.768838763
## 514 0.798347113 3.024622e-02 0.768100896
## 794 0.798280795 3.023414e-02 0.768046657
## 498 0.797964138 3.017657e-02 0.767787570
## 372 0.969833028 2.020887e-01 0.767744288
## 489 0.797805599 3.014781e-02 0.767657788
## 443 0.797427507 3.007940e-02 0.767348105
## 223 0.796415362 2.989748e-02 0.766517886
## 366 0.795763905 2.978130e-02 0.765982605
## 910 0.794465879 2.955193e-02 0.764913946
## 180 0.794356754 2.953278e-02 0.764823977
## 845 0.970496396 2.058096e-01 0.764686746
## 615 0.793890532 2.945116e-02 0.764439376
## 897 0.793497572 2.938264e-02 0.764114935
## 448 0.970634515 2.066010e-01 0.764033496
## 371 0.970686697 2.069015e-01 0.763785166
## 232 0.970724995 2.071226e-01 0.763602375
## 769 0.791915878 2.910936e-02 0.762806516
## 426 0.971263494 2.102802e-01 0.760983287
## 856 0.788609482 2.855084e-02 0.760058645
## 464 0.971491654 2.116462e-01 0.759845445
## 493 0.971583131 2.121987e-01 0.759384438
## 164 0.971697080 2.128908e-01 0.758806273
## 502 0.784895411 2.794319e-02 0.756952218
## 365 0.972095447 2.153450e-01 0.756750423
## 84 0.972188194 2.159243e-01 0.756263934
## 558 0.783245735 2.767974e-02 0.755565996
## 103 0.781959138 2.747694e-02 0.754482199
## 592 0.781939281 2.747383e-02 0.754465454
## 350 0.972667911 2.189689e-01 0.753699038
## 943 0.972711442 2.192493e-01 0.753462189
## 11 0.780510372 2.725132e-02 0.753259049
## 182 0.780501786 2.724999e-02 0.753251791
## 810 0.780098558 2.718772e-02 0.752910842
## 667 0.779931620 2.716200e-02 0.752769624
## 137 0.779491687 2.709440e-02 0.752397290
## 73 0.778514756 2.694521e-02 0.751569543
## 251 0.973069032 2.215790e-01 0.751490069
## 51 0.776648287 2.666369e-02 0.749984594
## 619 0.776440529 2.663264e-02 0.749807891
## 248 0.776351132 2.661929e-02 0.749731841
## 797 0.775680071 2.651944e-02 0.749160632
## 215 0.973503241 2.244729e-01 0.749030351
## 817 0.973514871 2.245514e-01 0.748963466
## 165 0.973551230 2.247972e-01 0.748754014
## 358 0.774857432 2.639782e-02 0.748459616
## 336 0.774478830 2.634213e-02 0.748136700
## 216 0.773962923 2.626654e-02 0.747696384
## 456 0.973787482 2.264072e-01 0.747380320
## 779 0.973936121 2.274315e-01 0.746504590
## 853 0.973979804 2.277343e-01 0.746245519
## 992 0.973988927 2.277976e-01 0.746191315
## 722 0.772185404 2.600863e-02 0.746176777
## 926 0.772059354 2.599049e-02 0.746068869
## 501 0.974035368 2.281205e-01 0.745914859
## 657 0.974143074 2.288728e-01 0.745270286
## 340 0.974454204 2.310731e-01 0.743381132
## 309 0.974463449 2.311391e-01 0.743324370
## 307 0.766950370 2.527115e-02 0.741679224
## 250 0.766010355 2.514210e-02 0.740868253
## 431 0.765931407 2.513131e-02 0.740800097
## 915 0.974869815 2.340768e-01 0.740792989
## 816 0.974895573 2.342655e-01 0.740630096
## 83 0.765114425 2.502004e-02 0.740094385
## 188 0.974994886 2.349956e-01 0.739999299
## 247 0.975101426 2.357837e-01 0.739317680
## 268 0.975457652 2.384566e-01 0.737001098
## 32 0.761460821 2.453146e-02 0.736929362
## 642 0.975545193 2.391224e-01 0.736422807
## 790 0.975877006 2.416791e-01 0.734197873
## 749 0.758218834 2.410989e-02 0.734108941
## 406 0.756858094 2.393620e-02 0.732921898
## 277 0.976097695 2.434091e-01 0.732688559
## 531 0.976117797 2.435679e-01 0.732549887
## 622 0.755872542 2.381156e-02 0.732060981
## 748 0.751386504 2.325635e-02 0.728130154
## 222 0.976818324 2.492291e-01 0.727589259
## 483 0.977000173 2.507406e-01 0.726259622
## 475 0.749013224 2.297040e-02 0.726042820
## 692 0.748753772 2.293946e-02 0.725814310
## 463 0.746985016 2.273016e-02 0.724254860
## 879 0.746811059 2.270972e-02 0.724101335
## 516 0.977305506 2.533188e-01 0.723986749
## 370 0.746465788 2.266925e-02 0.723796537
## 822 0.746342658 2.265484e-02 0.723687814
## 230 0.745155098 2.251658e-02 0.722638519
## 793 0.744941347 2.249182e-02 0.722449522
## 347 0.977717072 2.568764e-01 0.720840648
## 684 0.977783637 2.574609e-01 0.720322688
## 16 0.741692755 2.212051e-02 0.719572248
## 707 0.740740201 2.201334e-02 0.718726861
## 90 0.978109556 2.603606e-01 0.717748962
## 477 0.978147776 2.607048e-01 0.717442997
## 162 0.739140802 2.183511e-02 0.717305693
## 533 0.738525130 2.176707e-02 0.716758065
## 679 0.738028727 2.171243e-02 0.716316298
## 87 0.737651247 2.167102e-02 0.715980231
## 3 0.734220295 2.129985e-02 0.712920447
## 632 0.734196747 2.129733e-02 0.712899414
## 857 0.733866823 2.126214e-02 0.712604686
## 334 0.732231216 2.108889e-02 0.711142321
## 455 0.978969577 2.683252e-01 0.710644336
## 582 0.731266620 2.098769e-02 0.710278933
## 170 0.979056934 2.691608e-01 0.709896142
## 988 0.979077428 2.693575e-01 0.709719880
## 280 0.729296311 2.078313e-02 0.708513179
## 4 0.728069809 2.065725e-02 0.707412556
## 851 0.727551302 2.060437e-02 0.706946933
## 911 0.727139856 2.056254e-02 0.706577313
## 112 0.979467401 2.731554e-01 0.706311954
## 143 0.726757558 2.052379e-02 0.706233768
## 348 0.724759738 2.032298e-02 0.704436762
## 311 0.979697146 2.754420e-01 0.704255141
## 604 0.979764320 2.761176e-01 0.703646713
## 865 0.979852511 2.770095e-01 0.702843021
## 532 0.721649510 2.001592e-02 0.701633587
## 991 0.721137617 1.996603e-02 0.701171591
## 20 0.718816032 1.974194e-02 0.699074089
## 278 0.980321993 2.818533e-01 0.698468710
## 509 0.980529282 2.840447e-01 0.696484539
## 29 0.715686989 1.944556e-02 0.696241433
## 368 0.714885445 1.937065e-02 0.695514794
## 388 0.714825165 1.936503e-02 0.695460130
## 939 0.714287526 1.931504e-02 0.694972484
## 35 0.713433349 1.923599e-02 0.694197358
## 930 0.713073968 1.920287e-02 0.693871100
## 342 0.712194291 1.912213e-02 0.693072160
## 849 0.711589636 1.906691e-02 0.692522722
## 995 0.980954950 2.886505e-01 0.692304483
## 139 0.710207171 1.894151e-02 0.691265661
## 908 0.707429529 1.869304e-02 0.688736493
## 326 0.705840511 1.855295e-02 0.687287566
## 287 0.705466074 1.852015e-02 0.686945925
## 562 0.704185407 1.840859e-02 0.685776820
## 318 0.981900880 2.994245e-01 0.682476349
## 101 0.700119599 1.806056e-02 0.682059043
## 584 0.699036111 1.796936e-02 0.681066754
## 734 0.698455623 1.792076e-02 0.680534864
## 229 0.982211582 3.031362e-01 0.679075336
## 392 0.982421987 3.057011e-01 0.676720866
## 976 0.982474395 3.063466e-01 0.676127820
## 65 0.982476147 3.063682e-01 0.676107949
## 361 0.982540411 3.071634e-01 0.675376996
## 593 0.982742873 3.096953e-01 0.673047622
## 899 0.689837213 1.722009e-02 0.672617121
## 382 0.689269952 1.717530e-02 0.672094648
## 389 0.982832966 3.108350e-01 0.671997956
## 820 0.982871626 3.113266e-01 0.671545022
## 99 0.982976363 3.126661e-01 0.670310279
## 672 0.983096596 3.142177e-01 0.668878934
## 202 0.685132575 1.685340e-02 0.668279178
## 479 0.685042696 1.684650e-02 0.668196201
## 413 0.685015437 1.684440e-02 0.668171034
## 656 0.983256845 3.163091e-01 0.666947713
## 580 0.680800919 1.652510e-02 0.664275820
## 227 0.983489357 3.193925e-01 0.664096869
## 994 0.680331804 1.649007e-02 0.663841739
## 720 0.983537013 3.200317e-01 0.663505305
## 383 0.679581285 1.643422e-02 0.663147061
## 314 0.678267710 1.633710e-02 0.661930607
## 331 0.678209314 1.633280e-02 0.661876510
## 43 0.677301617 1.626617e-02 0.661035451
## 921 0.675590928 1.614157e-02 0.659449360
## 544 0.672303557 1.590570e-02 0.656397861
## 490 0.670878048 1.580484e-02 0.655073204
## 235 0.984358121 3.314512e-01 0.652906917
## 526 0.668025054 1.560554e-02 0.652419511
## 128 0.667222141 1.555006e-02 0.651672085
## 945 0.666739763 1.551684e-02 0.651222918
## 161 0.664539391 1.536654e-02 0.649172852
## 925 0.664363535 1.535461e-02 0.649008926
## 618 0.662575376 1.523400e-02 0.647341379
## 634 0.662088353 1.520136e-02 0.646886991
## 159 0.661205023 1.514241e-02 0.646062617
## 528 0.660767032 1.511328e-02 0.645653748
## 284 0.659050880 1.499988e-02 0.644050995
## 808 0.658574466 1.496860e-02 0.643605864
## 953 0.985069267 3.420040e-01 0.643065283
## 299 0.657572081 1.490306e-02 0.642669021
## 301 0.985110075 3.426295e-01 0.642480589
## 924 0.656488991 1.483266e-02 0.641656330
## 312 0.654426004 1.469976e-02 0.639726240
## 671 0.653892267 1.466563e-02 0.639226635
## 190 0.652645370 1.458630e-02 0.638059074
## 716 0.651595640 1.451994e-02 0.637075704
## 842 0.651021742 1.448382e-02 0.636537921
## 367 0.985561004 3.496931e-01 0.635867857
## 332 0.985580361 3.500027e-01 0.635577623
## 298 0.985660813 3.512953e-01 0.634365538
## 114 0.985742003 3.526091e-01 0.633132858
## 450 0.985752895 3.527861e-01 0.632966759
## 905 0.985768063 3.530329e-01 0.632735148
## 94 0.646179960 1.418369e-02 0.631996266
## 638 0.640233953 1.382593e-02 0.626408022
## 885 0.639136092 1.376114e-02 0.625374956
## 683 0.986241776 3.609134e-01 0.625328385
## 384 0.986268893 3.613749e-01 0.624893963
## 835 0.986328836 3.623993e-01 0.623929559
## 470 0.637395728 1.365921e-02 0.623736520
## 417 0.636668217 1.361688e-02 0.623051335
## 153 0.986483655 3.650714e-01 0.621412274
## 48 0.633385097 1.342792e-02 0.619957175
## 855 0.633337686 1.342522e-02 0.619912468
## 420 0.632891559 1.339980e-02 0.619491756
## 447 0.631975305 1.334779e-02 0.618627510
## 608 0.631556258 1.332409e-02 0.618232165
## 825 0.628811254 1.317013e-02 0.615641124
## 203 0.986874856 3.719984e-01 0.614876473
## 446 0.626750673 1.305601e-02 0.613694661
## 621 0.986956207 3.734713e-01 0.613484915
## 985 0.625618542 1.299384e-02 0.612624705
## 935 0.987009770 3.744473e-01 0.612562431
## 481 0.625524631 1.298870e-02 0.612535935
## 517 0.624468063 1.293103e-02 0.611537033
## 788 0.987141957 3.768776e-01 0.610264338
## 781 0.987309648 3.800055e-01 0.607304183
## 428 0.618033078 1.258657e-02 0.605446512
## 731 0.615120215 1.243435e-02 0.602685864
## 92 0.615106482 1.243364e-02 0.602672843
## 598 0.615094479 1.243302e-02 0.602661462
## 824 0.615022057 1.242926e-02 0.602592796
## 504 0.987580654 3.851693e-01 0.602411352
## 982 0.614488921 1.240166e-02 0.602087261
## 611 0.987606334 3.856658e-01 0.601940577
## 965 0.613931238 1.237287e-02 0.601558370
## 727 0.613690166 1.236045e-02 0.601329720
## 940 0.613524493 1.235192e-02 0.601172574
## 288 0.987692080 3.873326e-01 0.600359525
## 704 0.611477190 1.224713e-02 0.599230060
## 993 0.987775165 3.889611e-01 0.598814039
## 23 0.987787698 3.892080e-01 0.598579741
## 156 0.987811075 3.896692e-01 0.598141893
## 421 0.987819540 3.898365e-01 0.597983077
## 147 0.609348427 1.213931e-02 0.597209114
## 491 0.987890704 3.912483e-01 0.596642402
## 258 0.607574508 1.205034e-02 0.595524165
## 744 0.988015725 3.937531e-01 0.594262672
## 701 0.606071696 1.197559e-02 0.594096110
## 241 0.605161787 1.193059e-02 0.593231194
## 275 0.603523102 1.185008e-02 0.591673026
## 407 0.603092409 1.182902e-02 0.591263387
## 803 0.988206959 3.976458e-01 0.590561204
## 33 0.601801012 1.176616e-02 0.590034852
## 13 0.601211807 1.173761e-02 0.589474195
## 512 0.600586090 1.170739e-02 0.588878705
## 360 0.600196994 1.168864e-02 0.588508359
## 264 0.988322318 4.000307e-01 0.588291648
## 691 0.599815632 1.167029e-02 0.588145339
## 902 0.988386737 4.013747e-01 0.587012048
## 244 0.598252848 1.159549e-02 0.586657361
## 586 0.988439745 4.024873e-01 0.585952435
## 670 0.595530099 1.146651e-02 0.584063591
## 540 0.988556844 4.049667e-01 0.583590113
## 31 0.988581237 4.054870e-01 0.583094238
## 694 0.594035064 1.139641e-02 0.582638655
## 75 0.988620945 4.063367e-01 0.582284228
## 617 0.593612909 1.137671e-02 0.582236203
## 259 0.593531012 1.137289e-02 0.582158123
## 324 0.988662962 4.072397e-01 0.581423291
## 349 0.589824672 1.120169e-02 0.578622986
## 425 0.988827284 4.108091e-01 0.578018231
## 310 0.588710309 1.115080e-02 0.577559506
## 138 0.587642960 1.110232e-02 0.576540640
## 599 0.988905547 4.125307e-01 0.576374803
## 499 0.583727066 1.092654e-02 0.572800531
## 304 0.989078120 4.163776e-01 0.572700492
## 690 0.583536075 1.091804e-02 0.572618030
## 907 0.583294786 1.090733e-02 0.572387457
## 610 0.582734003 1.088247e-02 0.571851532
## 652 0.581444973 1.082558e-02 0.570619392
## 844 0.580198785 1.077091e-02 0.569427878
## 631 0.989274772 4.208480e-01 0.568426790
## 60 0.579117152 1.072371e-02 0.568393442
## 789 0.577284273 1.064427e-02 0.566639998
## 362 0.576317048 1.060263e-02 0.565714421
## 548 0.989510731 4.263378e-01 0.563172948
## 847 0.573490202 1.048197e-02 0.563008230
## 827 0.989580193 4.279807e-01 0.561599447
## 37 0.571905250 1.041501e-02 0.561490243
## 524 0.570630545 1.036150e-02 0.560269043
## 346 0.567513088 1.023195e-02 0.557281133
## 391 0.566536309 1.019174e-02 0.556344569
## 89 0.989834701 4.341083e-01 0.555726391
## 484 0.564981658 1.012810e-02 0.554853557
## 511 0.564471280 1.010731e-02 0.554363974
## 207 0.558133363 9.853005e-03 0.548280357
## 363 0.558018024 9.848444e-03 0.548169580
## 488 0.557430669 9.825251e-03 0.547605418
## 563 0.557026452 9.809325e-03 0.547217127
## 967 0.990202592 4.432765e-01 0.546926103
## 649 0.555935498 9.766484e-03 0.546169014
## 96 0.990253429 4.445734e-01 0.545680030
## 140 0.990293971 4.456130e-01 0.544680972
## 200 0.554368213 9.705298e-03 0.544662915
## 886 0.550802466 9.567657e-03 0.541234809
## 738 0.990502315 4.510318e-01 0.539470491
## 166 0.990521422 4.515353e-01 0.538986151
## 958 0.546022388 9.386475e-03 0.536635913
## 141 0.545905202 9.382081e-03 0.536523121
## 863 0.990620943 4.541755e-01 0.536445483
## 522 0.990627673 4.543551e-01 0.536272574
## 539 0.543728394 9.300850e-03 0.534427544
## 867 0.990708153 4.565141e-01 0.534194035
## 854 0.543422158 9.289483e-03 0.534132675
## 120 0.990738659 4.573378e-01 0.533400858
## 157 0.538975999 9.126128e-03 0.529849871
## 534 0.538857059 9.121800e-03 0.529735259
## 255 0.537343930 9.066939e-03 0.528276991
## 564 0.535880513 9.014214e-03 0.526866299
## 393 0.991046910 4.658275e-01 0.525219458
## 578 0.533444294 8.927162e-03 0.524517132
## 661 0.991173332 4.693996e-01 0.521773775
## 519 0.991188917 4.698437e-01 0.521345245
## 706 0.530055429 8.807546e-03 0.521247883
## 144 0.529377422 8.783818e-03 0.520593605
## 134 0.991287670 4.726769e-01 0.518610721
## 265 0.991293729 4.728519e-01 0.518441843
## 385 0.991368374 4.750175e-01 0.516350899
## 471 0.524802214 8.625441e-03 0.516176773
## 237 0.523622302 8.585082e-03 0.515037220
## 869 0.520319302 8.473141e-03 0.511846161
## 815 0.518683046 8.418248e-03 0.510264798
## 252 0.991593386 4.816651e-01 0.509928241
## 614 0.991598207 4.818096e-01 0.509788626
## 566 0.991609723 4.821549e-01 0.509454790
## 124 0.516809063 8.355828e-03 0.508453235
## 246 0.514818011 8.290028e-03 0.506527982
## 12 0.514542481 8.280965e-03 0.506261517
## 635 0.512607827 8.217607e-03 0.504390221
## 14 0.511976167 8.197028e-03 0.503779140
## 906 0.510996323 8.165208e-03 0.502831115
## 468 0.991944108 4.923997e-01 0.499544440
## 603 0.991957539 4.928201e-01 0.499137417
## 245 0.506536704 8.021958e-03 0.498514745
## 658 0.506045039 8.006321e-03 0.498038718
## 573 0.503407748 7.922963e-03 0.495484785
## 337 0.502569589 7.896653e-03 0.494672936
## 513 0.502206239 7.885275e-03 0.494320965
## 952 0.501540105 7.864457e-03 0.493675648
## 848 0.500656452 7.836925e-03 0.492819526
## 708 0.992267663 5.027287e-01 0.489538936
## 900 0.494772063 7.656007e-03 0.487116056
## 682 0.493050558 7.603861e-03 0.485446697
## 756 0.492986056 7.601914e-03 0.485384143
## 240 0.992410147 5.074138e-01 0.484996317
## 127 0.492525529 7.588026e-03 0.484937503
## 386 0.492522430 7.587933e-03 0.484934497
## 171 0.491643909 7.561510e-03 0.484082400
## 355 0.992441813 5.084668e-01 0.483975041
## 650 0.992451231 5.087808e-01 0.483670459
## 951 0.490876080 7.538489e-03 0.483337590
## 662 0.992462715 5.091641e-01 0.483298581
## 752 0.489979377 7.511691e-03 0.482467686
## 344 0.992505511 5.105979e-01 0.481907630
## 453 0.992629145 5.147849e-01 0.477844273
## 710 0.992646757 5.153868e-01 0.477259954
## 530 0.482307914 7.286168e-03 0.475021745
## 928 0.481147696 7.252633e-03 0.473895063
## 655 0.480568305 7.235941e-03 0.473332364
## 699 0.992764765 5.194562e-01 0.473308546
## 285 0.477483561 7.147685e-03 0.470335876
## 42 0.992869109 5.231075e-01 0.469761595
## 560 0.476487979 7.119419e-03 0.469368560
## 163 0.474668734 7.068042e-03 0.467600691
## 937 0.992941537 5.256718e-01 0.467269704
## 178 0.992945007 5.257953e-01 0.467149681
## 665 0.992948545 5.259213e-01 0.467027252
## 135 0.473987062 7.048881e-03 0.466938181
## 546 0.992971486 5.267394e-01 0.466232048
## 933 0.993007130 5.280157e-01 0.464991476
## 737 0.471183973 6.970602e-03 0.464213371
## 984 0.993078015 5.305718e-01 0.462506233
## 913 0.993125022 5.322804e-01 0.460844648
## 932 0.467094371 6.857850e-03 0.460236521
## 729 0.466554690 6.843099e-03 0.459711592
## 404 0.465334195 6.809845e-03 0.458524350
## 364 0.465218490 6.806700e-03 0.458411789
## 812 0.463628041 6.763609e-03 0.456864431
## 333 0.462844903 6.742484e-03 0.456102419
## 872 0.462259776 6.726739e-03 0.455533037
## 521 0.993309456 5.390902e-01 0.454219290
## 68 0.460124988 6.669582e-03 0.453455406
## 771 0.459354487 6.649061e-03 0.452705426
## 441 0.459123721 6.642927e-03 0.452480795
## 980 0.457348086 6.595895e-03 0.450752191
## 791 0.456114714 6.563405e-03 0.449551309
## 34 0.993437211 5.439087e-01 0.449528507
## 831 0.454497191 6.521015e-03 0.447976176
## 303 0.993483623 5.456804e-01 0.447803270
## 664 0.450922554 6.428207e-03 0.444494346
## 680 0.993615427 5.507741e-01 0.442841334
## 841 0.447152263 6.331603e-03 0.440820660
## 893 0.445572075 6.291499e-03 0.439280576
## 396 0.445309907 6.284867e-03 0.439025040
## 205 0.993717307 5.547761e-01 0.438941186
## 435 0.438095567 6.104770e-03 0.431990798
## 400 0.993918906 5.628666e-01 0.431052297
## 583 0.993919436 5.628882e-01 0.431031241
## 409 0.993943124 5.638542e-01 0.430088941
## 313 0.435583844 6.043133e-03 0.429540711
## 891 0.435141021 6.032322e-03 0.429108699
## 174 0.434846981 6.025153e-03 0.428821828
## 289 0.430394439 5.917485e-03 0.424476954
## 226 0.430341588 5.916217e-03 0.424425371
## 693 0.428826551 5.879965e-03 0.422946585
## 778 0.427802877 5.855578e-03 0.421947299
## 861 0.424645683 5.780903e-03 0.418864780
## 30 0.421215040 5.700672e-03 0.415514368
## 322 0.420951884 5.694556e-03 0.415257327
## 645 0.994362521 5.815156e-01 0.412846914
## 79 0.417828791 5.622394e-03 0.412206398
## 828 0.417241411 5.608907e-03 0.411632505
## 663 0.994449118 5.852991e-01 0.409150040
## 922 0.410520742 5.456481e-03 0.405064262
## 787 0.994575504 5.909090e-01 0.403666546
## 21 0.409090843 5.424492e-03 0.403666351
## 271 0.408083539 5.402049e-03 0.402681490
## 506 0.407318294 5.385049e-03 0.401933245
## 214 0.406991538 5.377803e-03 0.401613734
## 125 0.994625432 5.931545e-01 0.401470953
## 261 0.994647809 5.941664e-01 0.400481437
## 482 0.403780828 5.307024e-03 0.398473804
## 184 0.994761401 5.993559e-01 0.395405478
## 688 0.994782119 6.003121e-01 0.394470028
## 859 0.994817973 6.019739e-01 0.392844029
## 104 0.994978831 6.095431e-01 0.385435760
## 191 0.387996980 4.969736e-03 0.383027244
## 597 0.386689277 4.942561e-03 0.381746716
## 986 0.385309010 4.914001e-03 0.380395009
## 467 0.384527117 4.897878e-03 0.379629239
## 697 0.382652931 4.859397e-03 0.377793534
## 495 0.995163390 6.184622e-01 0.376701165
## 283 0.380349180 4.812411e-03 0.375536770
## 968 0.379559333 4.796381e-03 0.374762952
## 315 0.379323811 4.791608e-03 0.374532203
## 136 0.379275963 4.790639e-03 0.374485324
## 541 0.995237880 6.221355e-01 0.373102396
## 476 0.376726814 4.739224e-03 0.371987590
## 654 0.376530047 4.735273e-03 0.371794774
## 978 0.376489627 4.734461e-03 0.371755166
## 131 0.375559573 4.715820e-03 0.370843753
## 300 0.374867551 4.701985e-03 0.370165566
## 623 0.995316220 6.260453e-01 0.369270870
## 887 0.373722014 4.679149e-03 0.369042864
## 116 0.372450703 4.653903e-03 0.367796800
## 627 0.372070603 4.646375e-03 0.367424229
## 515 0.369763845 4.600877e-03 0.365162968
## 561 0.368856188 4.583065e-03 0.364273122
## 957 0.368020335 4.566707e-03 0.363453628
## 17 0.995454124 6.330472e-01 0.362406886
## 201 0.365703641 4.521590e-03 0.361182051
## 328 0.363977049 4.488176e-03 0.359488873
## 472 0.995519969 6.364453e-01 0.359074707
## 743 0.363360943 4.476296e-03 0.358884646
## 40 0.363006744 4.469477e-03 0.358537267
## 434 0.360172063 4.415169e-03 0.355756893
## 262 0.995596848 6.404586e-01 0.355138273
## 554 0.354792909 4.313410e-03 0.350479499
## 408 0.995713659 6.466531e-01 0.349060608
## 959 0.352779729 4.275756e-03 0.348503974
## 990 0.995728046 6.474242e-01 0.348303826
## 685 0.348030288 4.187833e-03 0.343842455
## 814 0.995865031 6.548586e-01 0.341006407
## 220 0.344870253 4.130031e-03 0.340740222
## 69 0.343128246 4.098402e-03 0.339029844
## 747 0.995969899 6.606649e-01 0.335304963
## 100 0.995972921 6.608338e-01 0.335139161
## 751 0.336658497 3.982371e-03 0.332676126
## 852 0.996029440 6.640071e-01 0.332022301
## 668 0.996051097 6.652311e-01 0.330820030
## 805 0.334640470 3.946635e-03 0.330693835
## 212 0.996077768 6.667445e-01 0.329333230
## 88 0.996108461 6.684947e-01 0.327613742
## 442 0.327780199 3.826737e-03 0.323953463
## 64 0.996202231 6.738983e-01 0.322303975
## 987 0.996213348 6.745446e-01 0.321668730
## 709 0.324533790 3.770837e-03 0.320762953
## 644 0.323568178 3.754313e-03 0.319813865
## 870 0.322540735 3.736782e-03 0.318803953
## 279 0.321580961 3.720453e-03 0.317860508
## 444 0.319456839 3.684476e-03 0.315772363
## 402 0.996332329 6.815398e-01 0.314792492
## 282 0.996369872 6.837769e-01 0.312592955
## 944 0.315150699 3.612218e-03 0.311538481
## 970 0.312764920 3.572569e-03 0.309192351
## 892 0.312115175 3.561819e-03 0.308553357
## 864 0.306881591 3.475950e-03 0.303405641
## 535 0.305101980 3.447042e-03 0.301654938
## 86 0.996558367 6.952320e-01 0.301326344
## 234 0.996579915 6.965657e-01 0.300014173
## 218 0.301673373 3.391760e-03 0.298281613
## 702 0.996688250 7.033485e-01 0.293339748
## 91 0.293387930 3.260357e-03 0.290127573
## 146 0.996743723 7.068724e-01 0.289871292
## 67 0.996771621 7.086579e-01 0.288113712
## 785 0.996811725 7.112403e-01 0.285571465
## 518 0.288305698 3.181253e-03 0.285124445
## 76 0.996847209 7.135407e-01 0.283306544
## 674 0.285516259 3.138309e-03 0.282377950
## 549 0.996892119 7.164733e-01 0.280418807
## 242 0.996926744 7.187507e-01 0.278176046
## 574 0.276157326 2.996617e-03 0.273160709
## 542 0.275533770 2.987306e-03 0.272546464
## 954 0.997018200 7.248354e-01 0.272182801
## 492 0.273810637 2.961656e-03 0.270848982
## 543 0.997067776 7.281766e-01 0.268891219
## 351 0.269376800 2.896205e-03 0.266480594
## 399 0.997105911 7.307675e-01 0.266338424
## 228 0.267325731 2.866194e-03 0.264459537
## 639 0.266122713 2.848668e-03 0.263274045
## 547 0.264587694 2.826388e-03 0.261761306
## 494 0.263817363 2.815242e-03 0.261002121
## 433 0.263361659 2.808659e-03 0.260553000
## 46 0.261465386 2.781352e-03 0.258684034
## 376 0.261393325 2.780317e-03 0.258613008
## 93 0.261152616 2.776862e-03 0.258375754
## 172 0.997246102 7.404509e-01 0.256795161
## 836 0.258665421 2.741285e-03 0.255924136
## 196 0.256401907 2.709113e-03 0.253692794
## 777 0.254187940 2.677832e-03 0.251510108
## 177 0.253998667 2.675166e-03 0.251323501
## 762 0.252996504 2.661074e-03 0.250335430
## 675 0.252406833 2.652800e-03 0.249754033
## 804 0.245277898 2.553778e-03 0.242724120
## 149 0.997491714 7.580426e-01 0.239449157
## 643 0.236882899 2.439518e-03 0.234443381
## 796 0.236058887 2.428437e-03 0.233630450
## 961 0.997586839 7.650799e-01 0.232506933
## 133 0.234050359 2.401525e-03 0.231648833
## 624 0.997599896 7.660560e-01 0.231543929
## 58 0.233850320 2.398853e-03 0.231451468
## 705 0.233083749 2.388624e-03 0.230695126
## 109 0.231667670 2.369781e-03 0.229297889
## 882 0.997633471 7.685773e-01 0.229056207
## 888 0.229700329 2.343717e-03 0.227356612
## 432 0.227018120 2.308393e-03 0.224709727
## 646 0.226803503 2.305577e-03 0.224497926
## 55 0.997702114 7.737833e-01 0.223918802
## 28 0.225682184 2.290890e-03 0.223391294
## 507 0.225248325 2.285218e-03 0.222963107
## 758 0.997720643 7.752006e-01 0.222520065
## 833 0.222773971 2.252993e-03 0.220520978
## 305 0.221810162 2.240495e-03 0.219569667
## 837 0.997761710 7.783602e-01 0.219401547
## 602 0.997769810 7.789864e-01 0.218783446
## 233 0.220565558 2.224402e-03 0.218341156
## 50 0.220244100 2.220254e-03 0.218023847
## 878 0.220082721 2.218172e-03 0.217864549
## 8 0.219731831 2.213650e-03 0.217518182
## 550 0.219273930 2.207754e-03 0.217066176
## 54 0.216855860 2.176734e-03 0.214679126
## 359 0.216827001 2.176365e-03 0.214650636
## 612 0.216609292 2.173582e-03 0.214435710
## 795 0.997862661 7.862365e-01 0.211626200
## 25 0.212106697 2.116358e-03 0.209990338
## 898 0.209390882 2.082155e-03 0.207308727
## 457 0.997932701 7.917945e-01 0.206138194
## 883 0.997933541 7.918616e-01 0.206071952
## 329 0.997979493 7.955512e-01 0.202428302
## 95 0.203928341 2.014059e-03 0.201914282
## 497 0.201846004 1.988343e-03 0.199857661
## 41 0.200586896 1.972858e-03 0.198614037
## 718 0.200215619 1.968302e-03 0.198247317
## 217 0.198883564 1.951987e-03 0.196931576
## 508 0.998051657 8.014145e-01 0.196637112
## 740 0.998061275 8.022025e-01 0.195858762
## 505 0.193578360 1.887541e-03 0.191690819
## 158 0.998143102 8.089687e-01 0.189174442
## 772 0.188858162 1.830903e-03 0.187027259
## 536 0.187395475 1.813484e-03 0.185581991
## 826 0.178914325 1.713697e-03 0.177200628
## 115 0.998291498 8.215320e-01 0.176759488
## 321 0.177752116 1.700181e-03 0.176051934
## 414 0.171956908 1.633349e-03 0.170323559
## 9 0.170895572 1.621210e-03 0.169274363
## 703 0.170158227 1.612794e-03 0.168545433
## 715 0.170036554 1.611407e-03 0.168425147
## 22 0.169085538 1.600578e-03 0.167484961
## 628 0.998409312 8.317848e-01 0.166624478
## 1000 0.161735559 1.517704e-03 0.160217855
## 923 0.160361337 1.502369e-03 0.158858968
## 811 0.998503714 8.401850e-01 0.158318693
## 746 0.998509388 8.406952e-01 0.157814158
## 609 0.157734385 1.473192e-03 0.156261193
## 209 0.998559357 8.452154e-01 0.153343945
## 53 0.151589806 1.405644e-03 0.150184161
## 696 0.998645660 8.531369e-01 0.145508782
## 354 0.998665537 8.549822e-01 0.143683347
## 753 0.144297413 1.326726e-03 0.142970686
## 185 0.139682702 1.277471e-03 0.138405231
## 175 0.998724751 8.605267e-01 0.138198079
## 199 0.137872295 1.258290e-03 0.136614005
## 637 0.998756250 8.635051e-01 0.135251109
## 689 0.131789818 1.194429e-03 0.130595390
## 412 0.130881137 1.184964e-03 0.129696173
## 629 0.129714361 1.172840e-03 0.128541521
## 936 0.129310979 1.168656e-03 0.128142323
## 979 0.128622279 1.161522e-03 0.127460758
## 782 0.998839338 8.714608e-01 0.127378553
## 839 0.126576182 1.140391e-03 0.125435791
## 335 0.126534206 1.139958e-03 0.125394247
## 253 0.125229592 1.126537e-03 0.124103055
## 750 0.998876314 8.750481e-01 0.123828225
## 167 0.123767362 1.111542e-03 0.122655819
## 765 0.122978550 1.103474e-03 0.121875076
## 764 0.122726792 1.100901e-03 0.121625891
## 74 0.122509221 1.098680e-03 0.121410542
## 478 0.121880805 1.092269e-03 0.120788536
## 238 0.119974542 1.072877e-03 0.118901665
## 440 0.998927656 8.800779e-01 0.118849733
## 802 0.119166671 1.064684e-03 0.118101987
## 377 0.998970743 8.843436e-01 0.114627156
## 766 0.115579518 1.028484e-03 0.114551034
## 71 0.113748490 1.010118e-03 0.112738373
## 838 0.998989896 8.862529e-01 0.112736978
## 963 0.113730505 1.009938e-03 0.112720567
## 895 0.111097100 9.836560e-04 0.110113444
## 352 0.999034636 8.907449e-01 0.108289702
## 61 0.108165451 9.545788e-04 0.107210872
## 375 0.999045876 8.918805e-01 0.107165371
## 537 0.107878821 9.517460e-04 0.106927075
## 874 0.107294332 9.459752e-04 0.106348357
## 969 0.105113702 9.245109e-04 0.104189191
## 553 0.104332684 9.168485e-04 0.103415835
## 297 0.103816534 9.117918e-04 0.102904742
## 296 0.102908000 9.029051e-04 0.102005095
## 356 0.098000409 8.552090e-04 0.097145200
## 254 0.999158969 9.034687e-01 0.095690269
## 659 0.096374504 8.395203e-04 0.095534983
## 607 0.094656542 8.230040e-04 0.093833537
## 213 0.093905462 8.158028e-04 0.093089659
## 868 0.092221350 7.996986e-04 0.091421651
## 581 0.092168149 7.991908e-04 0.091368958
## 381 0.091825774 7.959245e-04 0.091029850
## 576 0.091042316 7.884594e-04 0.090253857
## 269 0.091022679 7.882725e-04 0.090234406
## 197 0.999212163 9.090230e-01 0.090189115
## 736 0.089221625 7.711603e-04 0.088450465
## 179 0.088655648 7.657967e-04 0.087889851
## 316 0.999234811 9.114085e-01 0.087826299
## 66 0.088351313 7.629153e-04 0.087588398
## 211 0.086103408 7.416915e-04 0.085361717
## 929 0.999273646 9.155280e-01 0.083745692
## 295 0.999286039 9.168503e-01 0.082435750
## 606 0.082412766 7.070697e-04 0.081705696
## 281 0.999293873 9.176881e-01 0.081605738
## 996 0.081575982 6.992582e-04 0.080876724
## 651 0.080768392 6.917326e-04 0.080076659
## 317 0.079399712 6.790083e-04 0.078720704
## 647 0.999323354 9.208549e-01 0.078468484
## 571 0.078356846 6.693382e-04 0.077687508
## 776 0.077748595 6.637082e-04 0.077084886
## 975 0.075590588 6.437926e-04 0.074946795
## 210 0.075498994 6.429493e-04 0.074856044
## 113 0.075492065 6.428855e-04 0.074849180
## 783 0.073918207 6.284219e-04 0.073289785
## 405 0.999377643 9.267433e-01 0.072634322
## 860 0.070908128 6.008950e-04 0.070307233
## 480 0.070127297 5.937832e-04 0.069533514
## 934 0.999414555 9.307898e-01 0.068624735
## 130 0.067825356 5.728860e-04 0.067252470
## 971 0.999434884 9.330334e-01 0.066401450
## 927 0.066144658 5.576930e-04 0.065586965
## 273 0.065856664 5.550950e-04 0.065301569
## 807 0.065122257 5.484773e-04 0.064573780
## 565 0.064517521 5.430357e-04 0.063974486
## 186 0.061242470 5.136868e-04 0.060728783
## 918 0.999488394 9.389905e-01 0.060497851
## 2 0.060835480 5.100538e-04 0.060325427
## 733 0.999507646 9.411524e-01 0.058355274
## 754 0.999509058 9.413113e-01 0.058197802
## 850 0.057204726 4.777814e-04 0.056726945
## 626 0.054460655 4.535535e-04 0.054007102
## 714 0.053566642 4.456902e-04 0.053120952
## 189 0.052277301 4.343756e-04 0.051842925
## 45 0.999604314 9.521580e-01 0.047446335
## 949 0.044789600 3.692666e-04 0.044420333
## 818 0.999635589 9.557735e-01 0.043862094
## 881 0.041517497 3.411309e-04 0.041176366
## 741 0.041410228 3.402118e-04 0.041070016
## 181 0.999681885 9.611758e-01 0.038506081
## 111 0.999684757 9.615129e-01 0.038171873
## 653 0.999714285 9.649927e-01 0.034721550
## 677 0.034528046 2.816646e-04 0.034246382
## 858 0.999725157 9.662802e-01 0.033444975
## 401 0.999728375 9.666619e-01 0.033066466
## 19 0.032987016 2.686682e-04 0.032718348
## 224 0.032589035 2.653185e-04 0.032323717
## 813 0.030801424 2.503062e-04 0.030551118
## 121 0.999755291 9.698667e-01 0.029888597
## 473 0.029485490 2.392901e-04 0.029246200
## 243 0.028284295 2.292603e-04 0.028055034
## 155 0.025318253 2.045994e-04 0.025113653
## 551 0.024946191 2.015164e-04 0.024744675
## 700 0.999804301 9.757567e-01 0.024047625
## 325 0.023761360 1.917142e-04 0.023569646
## 38 0.999821059 9.777869e-01 0.022034170
## 919 0.018511836 1.485669e-04 0.018363269
## 686 0.999869178 9.836634e-01 0.016205791
## 437 0.015984383 1.279560e-04 0.015856427
## 871 0.999872139 9.840273e-01 0.015844829
## 998 0.015526673 1.242347e-04 0.015402438
## 487 0.015136788 1.210675e-04 0.015015720
## 972 0.015125884 1.209789e-04 0.015004905
## 725 0.999879504 9.849336e-01 0.014945900
## 102 0.014979115 1.197873e-04 0.014859327
## 123 0.014532486 1.161634e-04 0.014416322
## 118 0.014527540 1.161233e-04 0.014411416
## 676 0.013184480 1.052455e-04 0.013079234
## 378 0.012433105 9.917273e-05 0.012333932
## 151 0.012358979 9.857413e-05 0.012260405
## 941 0.011828070 9.428936e-05 0.011733781
## 821 0.005799454 4.595317e-05 0.005753501
## 438 0.005776271 4.576841e-05 0.005730503
## 461 0.002873117 2.269945e-05 0.002850418
Further use of the results:
Evaluate Uplift: Analyze the distribution of uplift scores and assess the overall lift generated by the treatment compared to the control. Positive uplift scores indicate that the treatment is effective in increasing the probability of the desired outcome, while negative uplift scores indicate the opposite.
Segmentation (Optional): Optionally, segment the data based on uplift scores to identify subsets of individuals or segments with the highest uplift. This can help tailor marketing strategies or interventions to maximize the impact on those who are most responsive to the treatment.
Validate and Iterate: Validate the uplift scores and iterate on the modeling process as needed. You may need to refine the model, feature engineering, or treatment strategies to improve uplift.
Deploy and Monitor: Once satisfied with the uplift model’s performance, deploy it in production and monitor its effectiveness over time. Continuously evaluate and refine the model based on real-world feedback and performance metrics.
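In the output above, the third column equals the first column minus the second, that is, the uplift is the difference between two predicted probabilities. The ranking and segmentation steps can be sketched as follows. This is a minimal illustration on made-up numbers; the names scores, p_treat, and p_control are hypothetical and do not come from the analysis above.

```r
# Hypothetical predicted probabilities for six individuals (made-up values)
scores <- data.frame(
  id        = 1:6,
  p_treat   = c(0.31, 0.99, 0.12, 0.45, 0.08, 0.60),  # P(outcome | treated)
  p_control = c(0.01, 0.70, 0.10, 0.50, 0.08, 0.20)   # P(outcome | not treated)
)

# Uplift = difference between the two predicted probabilities
scores$uplift <- scores$p_treat - scores$p_control

# Rank individuals from highest to lowest uplift
ranked <- scores[order(-scores$uplift), ]

# Segment: keep the individuals with the highest uplift (here the top 2)
top_segment <- head(ranked, 2)
top_segment
```

Individuals with a negative uplift (id 4 in this toy data) are those for whom the treatment is predicted to lower the probability of the desired outcome.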
Cluster analysis is an unsupervised learning technique used to group similar objects or data points into clusters based on their characteristics or features. The goal of cluster analysis is to find natural groupings in data without any prior knowledge of group memberships.
Key Concepts:
Clusters: Groups of data points that are similar to each other within the cluster but different from data points in other clusters.
Similarity or Distance: Measures used to determine the similarity between data points, such as Euclidean distance, Manhattan distance, or cosine similarity.
Centroids: Representative points within each cluster that summarize the characteristics of the data points in the cluster. The centroid is often used as a reference point for assigning data points to clusters.
Clustering Algorithm: A method or procedure used to partition data into clusters based on similarity or distance measures. Common clustering algorithms include K-means and hierarchical clustering.
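The distance measures and the centroid mentioned above can be computed directly in R. The following is a small sketch on two made-up observations; the object names are ours, and the cosine similarity is computed by hand because base dist() does not offer it.

```r
# Two made-up observations with three features each
x <- rbind(a = c(1, 2, 3),
           b = c(4, 6, 3))

# Euclidean distance: square root of the sum of squared differences
euclid <- as.numeric(dist(x, method = "euclidean"))

# Manhattan distance: sum of absolute differences
manhat <- as.numeric(dist(x, method = "manhattan"))

# Cosine similarity: dot product divided by the product of vector lengths
cosine_sim <- sum(x["a", ] * x["b", ]) /
  (sqrt(sum(x["a", ]^2)) * sqrt(sum(x["b", ]^2)))

# Centroid of the two observations: the column-wise mean
centroid <- colMeans(x)

c(euclidean = euclid, manhattan = manhat, cosine = cosine_sim)
centroid
```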
Applications of Cluster Analysis:
Hierarchical clustering is a method of cluster analysis that builds a hierarchy of clusters. In hierarchical clustering, data points are grouped together based on their similarity or distance. The result is a dendrogram, a tree-like diagram that shows the arrangement of the clusters.
Key Concepts:
Agglomerative vs. Divisive:
Distance between observations:
Linkage Methods:
Dendrogram:
A nice reference for agglomerative clustering is: https://www.datanovia.com/en/lessons/agglomerative-hierarchical-clustering/
A nice reference for divisive clustering is: https://www.datanovia.com/en/lessons/divisive-hierarchical-clustering/
Example 1. Hierarchical clustering using the built-in iris dataset.
# Perform hierarchical clustering
dist_matrix <- dist(iris[, 1:4], method = "euclidean")
hc_result <- hclust(dist_matrix, method = "complete")
# Plot the dendrogram
plot(hc_result, main = "Hierarchical Clustering Dendrogram", xlab = "Samples", ylab = "Distance")
If we use 3 clusters, which plants are in which clusters?
clusters = cutree(hc_result, k = 3)
clusters
## [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [38] 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 3 2 3 2 3 2 3 3 3 3 2 3 2 3 3 2 3 2 3 2 2
## [75] 2 2 2 2 2 3 3 3 3 2 3 2 2 2 3 3 3 2 3 3 3 3 3 2 3 3 2 2 2 2 2 2 3 2 2 2 2
## [112] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
## [149] 2 2
The results show that the first 50 plants are all in cluster 1, the last 50 plants are all in cluster 2 (except the 107th plant, which is in cluster 3), and the middle 50 plants are split between clusters 2 and 3.
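Because the iris data also record the species of each plant, one way to summarize the cluster memberships is to cross-tabulate them against the species. This is only a sketch for interpretation; the species labels are not used by the clustering itself, and the object names hc, cl, and tab are ours.

```r
# Repeat the clustering from the example (Euclidean distance, complete linkage)
hc <- hclust(dist(iris[, 1:4], method = "euclidean"), method = "complete")
cl <- cutree(hc, k = 3)

# Cross-tabulate species (rows) against cluster labels (columns)
tab <- table(Species = iris$Species, Cluster = cl)
tab
```

Each row of the table shows how the 50 plants of one species are distributed over the three clusters.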
The following is a heatmap of individual variables:
D = iris[, 1:4]
row.names(D) <- paste(clusters, ": ", row.names(D), sep = "")
heatmap(as.matrix(D), Colv = NA, hclustfun = hclust)
Dark cells denote higher values within a column.
DIY
Can you use the data from https://hdr.undp.org/data-center/human-development-index#/indicies/HDI to cluster countries by Life expectancy at birth, Expected years of schooling, Mean years of schooling, and Gross national income (GNI) per capita? If you use the Human Development Index (HDI) alone to cluster the countries, how do the results differ?
In K-means clustering, observations are grouped into clusters based on all variables, with the goal of minimizing the variance within each cluster.
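To make the idea of minimizing within-cluster variance concrete, here is a minimal sketch of the standard Lloyd iteration (assign each point to its nearest centroid, then recompute the centroids). The function simple_kmeans is a hypothetical helper written for illustration; it is not necessarily the algorithm that R's kmeans() runs by default.

```r
# simple_kmeans: illustrative Lloyd-style K-means (hypothetical helper)
simple_kmeans <- function(X, k, max_iter = 100) {
  X <- as.matrix(X)
  # Initialize: pick k rows at random as starting centroids
  centers <- X[sample(nrow(X), k), , drop = FALSE]
  cluster <- integer(nrow(X))
  for (iter in seq_len(max_iter)) {
    # Assignment step: squared Euclidean distance to each centroid
    d2 <- sapply(seq_len(k), function(j) rowSums(sweep(X, 2, centers[j, ])^2))
    new_cluster <- max.col(-d2, ties.method = "first")
    # Stop when no observation changes cluster
    if (all(new_cluster == cluster)) break
    cluster <- new_cluster
    # Update step: move each centroid to the mean of its members
    for (j in seq_len(k)) {
      if (any(cluster == j)) {
        centers[j, ] <- colMeans(X[cluster == j, , drop = FALSE])
      }
    }
  }
  # Total within-cluster sum of squares (the quantity being minimized)
  wss <- sum((X - centers[cluster, , drop = FALSE])^2)
  list(cluster = cluster, centers = centers, tot.withinss = wss)
}

set.seed(123)
fit <- simple_kmeans(iris[, 1:4], k = 3)
table(fit$cluster)
fit$tot.withinss
```

The result depends on the random starting centroids, which is why a seed is set here and why practical implementations are usually run from several starting points.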
Key Concepts:
Steps in K-Means Clustering:
Choosing the Number of Clusters (K):
Example 1: K-Means Clustering with iris data
# Perform K-means clustering
set.seed(123)
kmeans_result <- kmeans(iris[, 1:4], centers = 3) # Assuming 3 clusters
# Plot the clusters
plot(iris[, c(1, 3)], col = kmeans_result$cluster,
main = "K-Means Clustering of Iris Dataset",
xlab = "Sepal Length", ylab = "Petal Length")
points(kmeans_result$centers[, c(1, 3)], col = 1:3, pch = 8, cex = 2)
Which observation is in which cluster?
kmeans_result$cluster
## [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
## [38] 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
## [75] 2 2 2 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 3 3 3 3 2 3 3 3 3
## [112] 3 3 2 2 3 3 3 3 2 3 2 3 2 3 3 2 2 3 3 3 3 3 2 3 3 3 3 2 3 3 3 2 3 3 3 2 3
## [149] 3 2
The results show that the first 50 plants are all in cluster 1, the middle 50 plants are all in cluster 2 (except the 53rd and 78th plants, which are in cluster 3), and the last 50 plants are split between clusters 2 and 3.
To select \(K\) using the elbow method, we can use the fviz_nbclust() function from the “factoextra” package. We demonstrate the method using the iris data.
library(factoextra)
## Welcome! Want to learn more? See two factoextra-related books at https://goo.gl/ve3WBa
fviz_nbclust(iris[, 1:4], kmeans, method = "wss") +
labs(title = "Elbow Plot for K-Means Clustering",
x = "Number of Clusters (k)",
y = "Total Within-Cluster Sum of Squares")
The elbow occurs at k = 2 or 3, so we can choose 2 or 3 clusters to create.
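The quantity plotted in an elbow plot is the total within-cluster sum of squares for each candidate k, so the same curve can be sketched with base R alone. The range 1 to 10 and nstart = 25 are our choices for illustration, not settings taken from fviz_nbclust().

```r
set.seed(123)
k_values <- 1:10

# Total within-cluster sum of squares for each number of clusters
wss <- sapply(k_values, function(k) {
  kmeans(iris[, 1:4], centers = k, nstart = 25)$tot.withinss
})

# Elbow plot: look for the k after which the curve flattens
plot(k_values, wss, type = "b",
     xlab = "Number of Clusters (k)",
     ylab = "Total Within-Cluster Sum of Squares")
```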
A time series is a sequence of data points or observations collected, recorded, or measured at successive points in time. These data points are typically collected at regular intervals, such as every second, minute, hour, day, week, month, quarter, or year. Time series data is used in various fields, including economics, finance, engineering, environmental science, and many others, to analyze trends, patterns, and behaviors over time.
As with cross-sectional data (data collected at one point in time), modeling time series data is done for either descriptive or predictive purposes. In descriptive modeling (ordinary time series analysis), a time series is modeled to determine its components in terms of seasonal patterns, trends, relation to external factors, etc. These can then be used for decision making and policy formulation. In contrast, time series forecasting uses the information in the time series (and perhaps other information) to forecast future values of that series.
A time series can be dissected into 4 components: level, trend, seasonality, and noise.
To identify the components of a time series, the first step is to examine a time plot (values over time) with calendar date on the horizontal axis.
Example 1. Ridership on Amtrak Trains
Amtrak, a US railway company, routinely collects data on ridership. We will use the monthly data between 1/1991 and 3/2004. Data are publicly available at <www.forecastingprincples.com>. To get the data, click on Data. Under T-Competition Data, click “time-series data” (the direct URL is <www.forecastingprincples.com/files/MHcomp1.xls> as of 2015). This file contains many time series. In the Monthly worksheet, column AI contains series M034.
library(forecast)
library(ggplot2)
Amtrak.data <- mlba::Amtrak
ridership.ts <- ts(Amtrak.data$Ridership, start = c(1991, 1), end = c(2004, 3), freq = 12 )
autoplot(ridership.ts, xlab = "Time", ylab = "Ridership (in 000s)") +
ylim(1300, 2300) +
labs(title = "Monthly Ridership on Amtrak Trains (in Thousands)") +
theme(plot.title = element_text(hjust = 0.5))
The time series plot reveals that
Zooming in to a shorter period (say 3 years) can reveal patterns that are hidden:
library(gridExtra)
##
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
##
## combine
ridership.ts.3yrs <- window(ridership.ts, start = c(1997, 1), end = c(1999, 12))
g1 <- autoplot(ridership.ts.3yrs, xlab = "Time", ylab = "Ridership (in 000s)") +
ylim(1300, 2300)
g2 <- autoplot(ridership.ts, xlab = "Time", ylab = "Ridership (in 000s)") +
ylim(1300, 2300) +
geom_smooth(method = "lm", formula = y~poly(x, 2), se = FALSE)
grid.arrange(g1, g2, nrow = 2)
### 10.2 Data Partitioning and Performance Evaluation
As in the case of cross-sectional data, we partition data into a training set and a holdout set. We use the training data to train a model and use the holdout data to assess the model.
The performance measures are the same as those used for predicting a quantitative response. One of the most popular metrics is the root mean squared error (RMSE).
We will use the naive forecast as a benchmark for comparison. The naive method always uses the most recent observation (ignoring all previous ones) to forecast every future value, so the forecasts form a horizontal line when plotted.
When the time series has seasonality, a seasonal naive forecast can be generated. For example, to forecast the value in March of a future year, just use the most recent available March value. We demonstrate the naive methods below:
# Define the number of test/holdout observations
nTest <- 36
# Calculate the number of training observations
nTrain <- length(ridership.ts) - nTest
# Extract the training time series data
train.ts <- window(ridership.ts,
start = c(1991, 1),
end = c(1991, nTrain))
# Extract the test time series data
test.ts <- window(ridership.ts,
start = c(1991, nTrain + 1),
end = c(1991, nTrain + nTest))
# Perform naive forecasting on the training data
naive.pred <- naive(train.ts, h = nTest)
# Perform seasonal naive forecasting on the training data
snaive.pred <- snaive(train.ts, h = nTest)
# Define colors for plotting
colData <- "blue"
colModel.naive <- "tomato"
colModel.snaive <- "orange"
# Plot the training data, test data, and forecasts
autoplot(train.ts, xlab = "Time", ylab = "Ridership (in 000s)", color = colData) +
autolayer(test.ts, linetype = 3, color = colData) +
autolayer(naive.pred, PI = FALSE, color = colModel.naive, size = 1.75) +
autolayer(snaive.pred, PI = FALSE, color = colModel.snaive, size = 0.75)
Explanation of code:
Naive and seasonal naive forecasting methods are applied to the training data using the naive() and snaive() functions, respectively. These functions generate forecasts for the specified horizon (h = nTest), which is the length of the test set.
The autoplot() function from the forecast package is used to create an initial plot of the training data. Additional layers are added using the autolayer() function to overlay the test data and the forecasts generated by the naive and seasonal naive methods. The PI = FALSE argument suppresses the plotting of prediction intervals, and the size argument (1.75 for the naive forecast, 0.75 for the seasonal naive forecast) adjusts the line thickness so that the two forecasts can be told apart.
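The logic of the two benchmark methods is simple enough to reproduce by hand. The sketch below uses a short made-up monthly series and base R only; it illustrates the idea and is not the internal code of naive() or snaive().

```r
# Made-up monthly series covering three full years
y <- ts(c(10, 12, 15, 14, 18, 20, 25, 24, 19, 16, 13, 11,
          11, 13, 16, 15, 19, 22, 27, 25, 20, 17, 14, 12,
          12, 14, 18, 16, 21, 23, 29, 27, 22, 18, 15, 13),
        start = c(2001, 1), frequency = 12)

yv <- as.numeric(y)
n  <- length(yv)
m  <- frequency(y)   # length of one season (12 months)
h  <- 6              # forecast horizon

# Naive forecast: repeat the last observed value h times
naive_fc <- rep(yv[n], h)

# Seasonal naive forecast: for each future period, reuse the most recent
# observation from the same season (same month in the last available year)
steps     <- seq_len(h)
snaive_fc <- yv[n + steps - m * ceiling(steps / m)]

naive_fc
snaive_fc
```

Because the toy series ends in December, the six seasonal naive forecasts are the January to June values of its last year.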
We next assess the performance of the two models (naive and snaive):
# Calculate accuracy measures for naive forecast
accuracy(naive.pred, test.ts)
## ME RMSE MAE MPE MAPE MASE
## Training set 2.45091 168.1470 125.2975 -0.3460027 7.271393 1.518906
## Test set -14.71772 142.7551 115.9234 -1.2749992 6.021396 1.405269
## ACF1 Theil's U
## Training set -0.2472621 NA
## Test set 0.2764480 0.8346967
# Calculate accuracy measures for seasonal naive forecast
accuracy(snaive.pred, test.ts)
## ME RMSE MAE MPE MAPE MASE ACF1
## Training set 13.93991 99.26557 82.49196 0.5850656 4.715251 1.000000 0.6400044
## Test set 54.72961 95.62433 84.09406 2.6527928 4.247656 1.019421 0.6373346
## Theil's U
## Training set NA
## Test set 0.5532435
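Several of the measures reported by accuracy() have simple definitions. As a sketch, the helpers below (me, rmse, mae, and mape are our own hypothetical names) compute four of them on made-up actual and forecast values.

```r
# Forecast error is defined as actual minus forecast
me   <- function(actual, pred) mean(actual - pred)                        # mean error
rmse <- function(actual, pred) sqrt(mean((actual - pred)^2))              # root mean squared error
mae  <- function(actual, pred) mean(abs(actual - pred))                   # mean absolute error
mape <- function(actual, pred) 100 * mean(abs((actual - pred) / actual))  # mean absolute % error

# Made-up values for illustration
actual <- c(100, 110, 120)
pred   <- c(90, 115, 120)

c(ME   = me(actual, pred),
  RMSE = rmse(actual, pred),
  MAE  = mae(actual, pred),
  MAPE = mape(actual, pred))
```

In the output above, the seasonal naive forecast has the smaller test-set RMSE (95.62 versus 142.76), so it is the stronger of the two benchmarks for these data.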
Generating Future Forecasts:
Once a good model is suggested, we use all the original data to train the model. This final model is then used to forecast future values. There are 3 advantages:
We will use the Amtrak data only.
The model is \[y_t = \beta_0+\beta_1 t + \epsilon\]
where \(y_t\) is the ridership at period \(t\) and \(\epsilon\) is the noise term. We are modeling 3 of the 4 time series components: level (\(\beta_0\)), trend (\(\beta_1\)), and noise.
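To see what the model means in terms of ordinary regression, the sketch below fits the same kind of linear trend with base lm(), using a time index t = 1, 2, ..., n as the predictor. We assume this is how the trend term is coded; the built-in AirPassengers series is used instead of the Amtrak data so that the sketch runs without extra packages.

```r
# Built-in monthly series, used here only for illustration
y <- as.numeric(AirPassengers)

# Time index t = 1, 2, ..., n plays the role of the trend variable
t <- seq_along(y)

# Linear trend model: y_t = beta_0 + beta_1 * t + noise
fit <- lm(y ~ t)
coef(fit)   # intercept (level) and slope (trend per period)

# Fitted trend line at each time point
trend_line <- fitted(fit)
head(trend_line)
```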
# Load necessary libraries
library(forecast)
library(ggplot2)
# Load Amtrak ridership data
Amtrak.data <- mlba::Amtrak
# Create time series object
ridership.ts <- ts(Amtrak.data$Ridership, start = c(1991, 1), end = c(2004, 3), freq = 12)
# Fit linear trend model to the time series data
ridership.lm <- tslm(ridership.ts ~ trend)
# Plot original time series data and fitted values from the linear trend model
autoplot(ridership.ts, xlab = "Time", ylab = "Ridership (in 000s)", color = "blue") +
autolayer(ridership.lm$fitted.values, color = "tomato", size = 0.75) +
ylim(1300, 2300)
Explanation of code:
The autolayer() function overlays the fitted values from the linear trend model (ridership.lm$fitted.values) on the plot. The ylim() function sets the y-axis limits to focus on the range of ridership values between 1300 and 2300.
Now, we create training and holdout sets.
# Define the number of test/holdout observations
nTest <- 36
# Calculate the number of training observations
nTrain <- length(ridership.ts) - nTest
# Extract the training time series data
train.ts <- window(ridership.ts,
start = c(1991, 1),
end = c(1991, nTrain))
# Extract the test/holdout time series data
test.ts <- window(ridership.ts,
start = c(1991, nTrain + 1),
end = c(1991, nTrain + nTest))
We then train a trend model and plot the fitted values and forecast errors for training and test data:
# Fit linear trend model to the training data
train.lm <- tslm(train.ts ~ trend)
# Load gridExtra library for arranging plots
library(gridExtra)
# Define colors for plotting
colData <- "blue"
colModel <- "tomato"
# Function to plot forecasted values
plotForecast <- function(model, train.ts, test.ts){
# Generate forecasts
model.pred <- forecast(model, h = length(test.ts), level = 0)
# Plot original training data
g <- autoplot(train.ts,
xlab = "Time",
ylab = "Ridership (in 000s)",
color = colData) +
# Overlay test data
autolayer(test.ts, color = colData) +
# Overlay fitted values from the model
autolayer(model$fitted.values, color = colModel, size = 0.75) +
# Overlay forecasted values
autolayer(model.pred$mean, color = colModel, size = 0.75) +
# Add vertical dashed line to indicate separation between training and test data
geom_vline(xintercept = 2001.167, lty= "dashed")
return(g)
}
# Function to plot forecast errors (residuals)
plotResiduals <- function(model, test.ts){
# Generate forecasts
model.pred <- forecast(model, h = length(test.ts), level = 0)
# Plot residuals
g <- autoplot(model$residuals,
xlab = "Time",
ylab = "Forecast Error",
color = colModel,
size = 0.75) +
# Overlay actual forecast errors
autolayer(test.ts - model.pred$mean, color = colModel) +
# Set y-axis limits based on range of residuals
ylim(range(model$residuals)) +
# Add horizontal reference line at y = 0 (forecast errors centered here indicate no bias)
geom_hline(yintercept = 0, color = "darkgrey") +
# Add vertical dashed line to indicate separation between training and test data
geom_vline(xintercept = 2001.167, lty= "dashed")
return(g)
}
# Generate plots for forecasted values and residuals
g1 <- plotForecast(train.lm, train.ts, test.ts)
g2 <- plotResiduals(train.lm, test.ts)
# Arrange plots in a grid
grid.arrange(g1, g2, nrow = 2)
## Warning: Removed 1 row containing missing values or values outside the scale range
## (`geom_line()`).
We assess the performance of the trend model:
# Summarize the linear trend model
summary(train.lm)
##
## Call:
## tslm(formula = train.ts ~ trend)
##
## Residuals:
## Min 1Q Median 3Q Max
## -411.29 -114.02 16.06 129.28 306.35
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 1750.3595 29.0729 60.206 <2e-16 ***
## trend 0.3514 0.4069 0.864 0.39
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 160.2 on 121 degrees of freedom
## Multiple R-squared: 0.006125, Adjusted R-squared: -0.002089
## F-statistic: 0.7456 on 1 and 121 DF, p-value: 0.3896
# Generate forecasts using the linear trend model for the length of the test data
train.lm.pred <- forecast(train.lm, h = length(test.ts), level = 0)
# Calculate accuracy measures for the forecasted values compared to the actual test data
accuracy(train.lm.pred, test.ts)
## ME RMSE MAE MPE MAPE MASE
## Training set 3.720601e-15 158.9269 129.6778 -0.853984 7.535999 1.572005
## Test set 1.931316e+02 239.4863 209.4371 9.209919 10.147732 2.538879
## ACF1 Theil's U
## Training set 0.4372087 NA
## Test set 0.2734545 1.358369
We transform the data on the log-scale and then train a model:
\[\log(y_t)=\beta_0 +\beta_1 t + \epsilon\] which is equivalent to training an exponential model:
\[y_t=ce^{\beta_1 t + \epsilon}\] where \(c = e^{\beta_0}\) and \(\epsilon\) is the error term.
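The equivalence can be checked with a small base R sketch: fit a linear model to log(y), then exponentiate to return to the original scale. The built-in AirPassengers series stands in for the Amtrak data, and the object names are ours.

```r
y <- as.numeric(AirPassengers)
t <- seq_along(y)

# Linear model on the log scale: log(y_t) = beta_0 + beta_1 * t + noise
fit_log <- lm(log(y) ~ t)

# Back-transform: c = exp(beta_0) is the level of the exponential model
c_hat <- exp(coef(fit_log)[["(Intercept)"]])

# exp(beta_1) - 1 is the implied growth rate per period
growth_rate <- exp(coef(fit_log)[["t"]]) - 1

# Fitted values on the original scale follow c * exp(beta_1 * t)
fitted_y <- exp(fitted(fit_log))

c(c_hat = c_hat, growth_rate = growth_rate)
```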
# Fit a trend model with exponential trend component as indicated by lambda = 0
train.lm.expo.trend <- tslm(train.ts ~ trend, lambda = 0)
# Generate forecasts using the exponential trend model
train.lm.expo.trend.pred <- forecast(train.lm.expo.trend, h = nTest, level = 0)
# Fit a trend model with linear trend component
train.lm.linear.trend <- tslm(train.ts ~ trend, lambda = 1)
# Generate forecasts using the linear trend model
train.lm.linear.trend.pred <- forecast(train.lm.linear.trend, h = nTest, level = 0)
# Plot forecasted values using exponential trend model
plotForecast(train.lm.expo.trend, train.ts, test.ts) +
# Overlay fitted values from the linear trend model
autolayer(train.lm.linear.trend$fitted.values, color = "green", size = 0.75) +
# Overlay forecasted values from the linear trend model
autolayer(train.lm.linear.trend.pred, color = "green", size = 0.75)
From the plot of the forecast errors, the shape is similar to the original time series, meaning that the seasonality is not modeled by the exponential trend model. We need to account for seasonality later.
The time series might show a polynomial (usually quadratic or cubic) trend. For a quadratic trend, we can train
\[y_t = \beta_0+\beta_1 t + \beta_2 t^2 + \epsilon\]
where \(\epsilon\) is the error term.
# Train a linear regression model with polynomial trend terms
train.lm.poly.trend <- tslm(train.ts ~ trend + I(trend^2))
# Summarize the trained model
summary(train.lm.poly.trend)
##
## Call:
## tslm(formula = train.ts ~ trend + I(trend^2))
##
## Residuals:
## Min 1Q Median 3Q Max
## -344.79 -101.86 40.89 98.54 279.81
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 1888.88401 40.91521 46.166 < 2e-16 ***
## trend -6.29780 1.52327 -4.134 6.63e-05 ***
## I(trend^2) 0.05362 0.01190 4.506 1.55e-05 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 148.8 on 120 degrees of freedom
## Multiple R-squared: 0.1499, Adjusted R-squared: 0.1358
## F-statistic: 10.58 on 2 and 120 DF, p-value: 5.844e-05
# Generate plots for forecast and residuals
g1 <- plotForecast(train.lm.poly.trend, train.ts, test.ts) # Plot forecast
g2 <- plotResiduals(train.lm.poly.trend, test.ts) # Plot residuals
# Arrange the plots in a grid layout
grid.arrange(g1, g2, nrow = 2)
In the expression I(trend^2), the I() wrapper tells R to treat trend^2 as is, that is, as the arithmetic square of trend rather than as formula syntax.
From the plot of the forecast errors, the shape is similar to the original time series, meaning that the seasonality is not modeled by the quadratic trend model. We need to account for seasonality, which we do next.
# Train model with additive seasonality; lambda = 1 applies no power transformation to the data
train.lm.season <- tslm(train.ts ~ season, lambda = 1)
summary(train.lm.season)
##
## Call:
## tslm(formula = train.ts ~ season, lambda = 1)
##
## Residuals:
## Min 1Q Median 3Q Max
## -276.165 -52.934 5.868 54.544 215.081
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 1572.97 30.58 51.442 < 2e-16 ***
## season2 -42.93 43.24 -0.993 0.3230
## season3 260.77 43.24 6.030 2.19e-08 ***
## season4 245.09 44.31 5.531 2.14e-07 ***
## season5 278.22 44.31 6.279 6.81e-09 ***
## season6 233.46 44.31 5.269 6.82e-07 ***
## season7 345.33 44.31 7.793 3.79e-12 ***
## season8 396.66 44.31 8.952 9.19e-15 ***
## season9 75.76 44.31 1.710 0.0901 .
## season10 200.61 44.31 4.527 1.51e-05 ***
## season11 192.36 44.31 4.341 3.14e-05 ***
## season12 230.42 44.31 5.200 9.18e-07 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 101.4 on 111 degrees of freedom
## Multiple R-squared: 0.6348, Adjusted R-squared: 0.5986
## F-statistic: 17.54 on 11 and 111 DF, p-value: < 2.2e-16
Explanation of results:
# Train model with multiplicative seasonality by setting lambda = 0
# This is equivalent to training an additive seasonality model on log-transformed data
train.lm.season <- tslm(train.ts ~ season, lambda = 0)
summary(train.lm.season)
##
## Call:
## tslm(formula = train.ts ~ season, lambda = 0)
##
## Residuals:
## Min 1Q Median 3Q Max
## -0.160824 -0.029798 0.005386 0.032081 0.116843
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 7.35909 0.01765 416.839 < 2e-16 ***
## season2 -0.02764 0.02497 -1.107 0.2706
## season3 0.15326 0.02497 6.138 1.32e-08 ***
## season4 0.14482 0.02558 5.661 1.20e-07 ***
## season5 0.16388 0.02558 6.405 3.73e-09 ***
## season6 0.13897 0.02558 5.432 3.33e-07 ***
## season7 0.19994 0.02558 7.815 3.38e-12 ***
## season8 0.22654 0.02558 8.855 1.53e-14 ***
## season9 0.04825 0.02558 1.886 0.0619 .
## season10 0.12114 0.02558 4.735 6.52e-06 ***
## season11 0.11563 0.02558 4.520 1.56e-05 ***
## season12 0.13760 0.02558 5.379 4.21e-07 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.05855 on 111 degrees of freedom
## Multiple R-squared: 0.6378, Adjusted R-squared: 0.6019
## F-statistic: 17.77 on 11 and 111 DF, p-value: < 2.2e-16
Explanation of results:
The “season” is a categorical variable recognized automatically by the “tslm” function, with the reference category being January.
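The role of the reference category can be illustrated with base R by turning the month of each observation into a factor and regressing on it. This is a sketch on the built-in AirPassengers series. We assume the season term behaves like such a factor, with level 1 (January) as the reference.

```r
y     <- as.numeric(AirPassengers)
month <- factor(as.integer(cycle(AirPassengers)))  # 1 = January, ..., 12 = December

# Additive seasonality as a regression on 11 dummy variables (January is the reference)
fit <- lm(y ~ month)

# The intercept is the January average; each coefficient is the difference
# between that month's average and the January average
jan_mean <- mean(y[month == "1"])
aug_mean <- mean(y[month == "8"])

c(intercept = coef(fit)[["(Intercept)"]], jan_mean = jan_mean)
c(month8 = coef(fit)[["month8"]], aug_minus_jan = aug_mean - jan_mean)
```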
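Under the model with multiplicative seasonality, the coefficient for season8, 0.2265, indicates that the average number of passengers in August is higher by \(100(e^{0.2265}-1)\%\), or about 25.43%, than the average in January (the reference category). Here is the detail: the model is fit on the log scale, so the coefficient is the difference between the fitted log ridership in August and in January, \(\log(\hat{y}_{Aug})-\log(\hat{y}_{Jan})=0.2265\). Exponentiating both sides gives \(\hat{y}_{Aug}/\hat{y}_{Jan}=e^{0.2265}\approx 1.254\), so August ridership is about 25.4% above January ridership.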
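The same conversion applies to every seasonal coefficient. The sketch below copies the estimates from the summary output above into a vector and converts them to percentage differences relative to January. The month labels are ours, assuming season2 through season12 correspond to February through December.

```r
# Seasonal coefficients copied from the lambda = 0 model summary
season_coef <- c(Feb = -0.02764, Mar = 0.15326, Apr = 0.14482, May = 0.16388,
                 Jun = 0.13897, Jul = 0.19994, Aug = 0.22654, Sep = 0.04825,
                 Oct = 0.12114, Nov = 0.11563, Dec = 0.13760)

# Percentage difference from January: 100 * (exp(coefficient) - 1)
pct_vs_january <- 100 * (exp(season_coef) - 1)
round(pct_vs_january, 2)
```

A negative coefficient, as for February, translates into ridership below the January level.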
We show how a trend and (additive) seasonality can be jointly modeled for Amtrak data. We consider a quadratic trend.
# Train a model with quadratic trend and (additive) seasonality
train.lm.trend.season <- tslm(train.ts ~ trend + I(trend^2) + season, lambda = 1)
summary(train.lm.trend.season)
##
## Call:
## tslm(formula = train.ts ~ trend + I(trend^2) + season, lambda = 1)
##
## Residuals:
## Min 1Q Median 3Q Max
## -213.775 -39.363 9.711 42.422 152.187
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 1.696e+03 2.768e+01 61.282 < 2e-16 ***
## trend -7.156e+00 7.293e-01 -9.812 < 2e-16 ***
## I(trend^2) 6.074e-02 5.698e-03 10.660 < 2e-16 ***
## season2 -4.325e+01 3.024e+01 -1.430 0.15556
## season3 2.600e+02 3.024e+01 8.598 6.60e-14 ***
## season4 2.606e+02 3.102e+01 8.401 1.83e-13 ***
## season5 2.938e+02 3.102e+01 9.471 6.89e-16 ***
## season6 2.490e+02 3.102e+01 8.026 1.26e-12 ***
## season7 3.606e+02 3.102e+01 11.626 < 2e-16 ***
## season8 4.117e+02 3.102e+01 13.270 < 2e-16 ***
## season9 9.032e+01 3.102e+01 2.911 0.00437 **
## season10 2.146e+02 3.102e+01 6.917 3.29e-10 ***
## season11 2.057e+02 3.103e+01 6.629 1.34e-09 ***
## season12 2.429e+02 3.103e+01 7.829 3.44e-12 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 70.92 on 109 degrees of freedom
## Multiple R-squared: 0.8246, Adjusted R-squared: 0.8037
## F-statistic: 39.42 on 13 and 109 DF, p-value: < 2.2e-16
# Generate plots for forecast and residuals
g1 <- plotForecast(train.lm.trend.season, train.ts, test.ts) # Plot forecast
g2 <- plotResiduals(train.lm.trend.season, test.ts) # Plot residuals
# Arrange the plots in a grid layout
grid.arrange(g1, g2, nrow = 2)
## Warning: Removed 1 row containing missing values or values outside the scale range
## (`geom_line()`).
Assess the performance of the model with the holdout data:
## ME RMSE MAE MPE MAPE MASE
## Training set 3.693205e-15 66.76143 51.95091 -0.1525653 3.015509 0.6297693
## Test set -1.261654e+02 153.25066 131.72503 -6.4314945 6.698700 1.5968226
## ACF1 Theil's U
## Training set 0.6040588 NA
## Test set 0.7069291 0.8960679
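The RMSE for the holdout data (153.25) is more than twice that for the training data (66.76), so the model may be overfitting the training data. The large negative mean error on the test set (about -126) also indicates that the forecasts are systematically higher than the actual values.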
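A performance table like the one above can be produced with accuracy() from the forecast package. The following is a minimal sketch, not the code that generated the output above: it assumes the built-in AirPassengers series as a stand-in for the Amtrak data, and the names train, test, fit, fc, and acc are illustrative.

```r
library(forecast)

# Stand-in data (assumption): AirPassengers replaces the Amtrak ridership series
train <- window(AirPassengers, end = c(1958, 12))   # training period
test  <- window(AirPassengers, start = c(1959, 1))  # holdout period

# Quadratic trend plus additive seasonality, as in the model above
fit <- tslm(train ~ trend + I(trend^2) + season)

# Forecast the holdout period and compare against the actual values
fc  <- forecast(fit, h = length(test))
acc <- accuracy(fc, test)

# Rows are "Training set" and "Test set"; a large gap in RMSE hints at overfitting
acc[, c("ME", "RMSE", "MAE", "MAPE")]
```

Comparing the two rows side by side is what supports the overfitting remark: the training row describes in-sample fit, while the test row describes out-of-sample forecasts.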
Many time series exhibit patterns where values at one time point are correlated with values at previous time points. Such correlation is called autocorrelation.
Autocorrelation is a statistical measure of the degree of similarity between a time series and a lagged version of itself. In other words, it measures how correlated a time series is with its own past values at different lags.
We can use the base acf() function or the Acf() function from the forecast package to calculate autocorrelations.
# Set up plotting parameters for side-by-side layout
par(mfrow = c(1, 2))
# Plot the first ACF plot (years)
acf(ridership.ts, lag.max = 36, main = "Autocorrelation Plot (Years)", xlab = "Lag (in years)")
# Plot the second ACF plot (months)
forecast::Acf(ridership.ts, lag.max = 36, main = "Autocorrelation Plot (Months)", xlab = "Lag (in months)")
# Reset plotting parameters to default
par(mfrow = c(1, 1))
Interpretation of the results:
The above result shows that the lag-1 autocorrelation is about 0.58 and the lag-2 autocorrelation is about 0.38. Only the lag-6 autocorrelation is not significantly different from 0 at the 5% significance level.
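To make the definition concrete, the lag-1 autocorrelation and the 5% significance bounds can be reproduced by hand. This is a sketch under the assumption that the built-in AirPassengers series stands in for the ridership data, so the numbers will differ from those quoted above; the variable names are illustrative.

```r
# Stand-in data (assumption): AirPassengers replaces the ridership series
y <- as.numeric(AirPassengers)
n <- length(y)
y_bar <- mean(y)

# Lag-1 sample autocorrelation computed from its definition
r1_manual <- sum((y[-1] - y_bar) * (y[-n] - y_bar)) / sum((y - y_bar)^2)

# The same quantity from acf(); element 1 is lag 0, element 2 is lag 1
r1_acf <- acf(y, lag.max = 1, plot = FALSE)$acf[2]

# Approximate 5% significance bounds, drawn as dashed lines in the ACF plot
bound <- qnorm(0.975) / sqrt(n)

c(manual = r1_manual, acf = r1_acf, bound = bound)
```

An autocorrelation whose absolute value falls below the bound is not significantly different from 0 at the 5% level.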
ARIMA (AutoRegressive Integrated Moving Average) models are a class of models used for analyzing and forecasting time series data. They are widely used in various fields including finance, economics, and environmental science. ARIMA models are capable of capturing many common time series patterns such as trends, seasonality, and autocorrelation.
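An ARIMA model is often denoted as \(ARIMA(p, d, q)\), where:
p: The order of the autoregressive (AR) part, that is, the number of lagged values of the series used as predictors.
d: The degree of differencing, that is, the number of times the series is differenced to make it stationary.
q: The order of the moving average (MA) part, that is, the number of lagged forecast errors used as predictors.
Special cases:
\(ARIMA(p, 0, 0)\) is an autoregressive model, \(AR(p)\).
\(ARIMA(0, 0, q)\) is a moving average model, \(MA(q)\).
\(ARIMA(p, 0, q)\) is an \(ARMA(p, q)\) model.
The ARIMA model can be applied in two ways:
Directly to the original time series.
To the residuals of another model fitted to the time series.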
We will use the second approach, applying ARIMA to the residuals of another model. An ARIMA model without differencing (\(d = 0\)) assumes a stationary series, whereas the ridership series has both trend and seasonality. Fitting an ARIMA model to the residuals of another model, such as a linear regression model, is a common technique in time series analysis. The rationale is to capture any remaining autocorrelation or temporal patterns in the residuals that were not accounted for by the initial model.
Let’s train an \(AR(1)\) model for the residuals of the regression model fitted to the Amtrak data:
# Fit a linear model with trend, quadratic trend, and seasonal components
train.lm.trend.season <- tslm(train.ts ~ trend + I(trend^2) + season, lambda = 1)
# ACF of the residuals of original model
Acf(train.lm.trend.season$residuals, lag.max = 12)
# Fit an ARIMA model to the residuals of the linear model
train.residual.arima <- Arima(train.lm.trend.season$residuals, order = c(1, 0, 0))
summary(train.residual.arima)
## Series: train.lm.trend.season$residuals
## ARIMA(1,0,0) with non-zero mean
##
## Coefficients:
## ar1 mean
## 0.5998 0.3728
## s.e. 0.0712 11.8408
##
## sigma^2 = 2876: log likelihood = -663.54
## AIC=1333.08 AICc=1333.29 BIC=1341.52
##
## Training set error measures:
## ME RMSE MAE MPE MAPE MASE
## Training set -0.1223159 53.19141 39.8277 103.3885 175.2219 0.5721356
## ACF1
## Training set -0.07509305
# ACF of the residuals of the ARIMA model for original residuals
Acf(train.residual.arima$residuals, lag.max = 12)
# Forecast one step ahead using the ARIMA model for the residuals
train.residual.arima.pred <- forecast(train.residual.arima, h = 1)
train.residual.arima.pred
## Point Forecast Lo 80 Hi 80 Lo 95 Hi 95
## Apr 2001 7.41111 -61.31748 76.1397 -97.70019 112.5224
The first ACF plot, for the residuals of the original model, shows strong autocorrelation, while the second, for the residuals of the ARIMA model, shows no significant autocorrelation.
The Arima() function reports the mean of the series rather than an intercept, so the autoregressive equation for the residual is \(\hat{\epsilon}_t = 0.3728+0.5998 \cdot (\hat{\epsilon}_{t-1} - 0.3728)\).
To use a forecasted residual, add it to the forecast from the model fitted to the original data. When the residuals are autocorrelated, the resulting forecast improves on the original forecast.
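The two steps, forecasting the residual and adding it to the regression forecast, can be sketched as follows. This is an illustration under stated assumptions rather than the analysis above: AirPassengers stands in for the Amtrak series, and the names fit, res_fit, phi, mu, and improved_fc are hypothetical. The sketch also checks by hand that the value labeled "mean" in the Arima() output enters the AR(1) forecast in mean-centered form.

```r
library(forecast)

# Stand-in data (assumption): AirPassengers replaces the Amtrak series
train <- window(AirPassengers, end = c(1958, 12))

# Regression with quadratic trend and additive seasonality
fit <- tslm(train ~ trend + I(trend^2) + season)

# AR(1) model for the regression residuals
res <- fit$residuals
res_fit <- Arima(res, order = c(1, 0, 0))
phi <- unname(coef(res_fit)[1])  # AR(1) coefficient
mu  <- unname(coef(res_fit)[2])  # reported mean of the residual series

# One-step residual forecast by hand, in mean-centered form
last_res <- as.numeric(res[length(res)])
res_fc_manual <- mu + phi * (last_res - mu)

# The same forecast from forecast()
res_fc <- as.numeric(forecast(res_fit, h = 1)$mean)

# Improved forecast = regression forecast + forecasted residual
base_fc <- as.numeric(forecast(fit, h = 1)$mean)
improved_fc <- base_fc + res_fc
c(base = base_fc, residual = res_fc, improved = improved_fc)
```

Because an AR(1) forecast decays toward the mean of the residuals as the horizon grows, this correction matters most for short-term forecasts.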
The centered moving average (CMA) is calculated by taking the average of data points within a symmetric window centered at each observation. This smooths out short-term fluctuations and emphasizes the underlying trend. You can use the rollmean() function from the zoo package with the align = "center" argument, or the ma() function from the forecast package, to compute the centered moving average. A window is exactly symmetric only when its width is an odd number; for an even width such as 12, ma() by default averages two adjacent windows to keep the result centered.
The trailing moving average, also known as the simple moving average, is computed by taking the average of the most recent data points, so the window ends at the current observation. You can use the rollmean() function with the align = "right" argument to compute the trailing moving average.
The centered moving average is typically used for visualization as it helps smooth out short-term fluctuations and highlight long-term trends, while the trailing moving average is more suitable for forecasting as it provides a more current estimate of the underlying trend.
library(ggplot2)
library(zoo)
##
## Attaching package: 'zoo'
## The following objects are masked from 'package:base':
##
## as.Date, as.Date.numeric
# Window size for moving averages
window_size <- 12
# Compute centered moving average
centered_ma <- rollmean(ridership.ts, k = window_size, align = "center", fill = NA)
# Compute trailing moving average
trailing_ma <- rollmean(ridership.ts, k = window_size, align = "right", fill = NA)
# Plot original time series and moving averages
autoplot(ridership.ts, series = "Original") +
autolayer(centered_ma, series = "Centered MA", size = 1.5) +
autolayer(trailing_ma, series = "Trailing MA", size = 1.5) +
labs(title = "Centered and Trailing Moving Averages", y = "Ridership", color = "Moving Average") +
theme(legend.position = "bottom")
## Warning: Removed 11 rows containing missing values or values outside the scale range
## (`geom_line()`).
## Removed 11 rows containing missing values or values outside the scale range
## (`geom_line()`).
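The difference between the two alignments is easiest to see on a very small series. The sketch below uses a made-up vector and a window of width 3 (both chosen purely for illustration) with the same rollmean() arguments as above.

```r
library(zoo)

x <- c(2, 4, 6, 8, 10, 12)  # toy series (illustrative values)

# Centered: each value averages the point itself and one neighbor on each side
centered <- rollmean(x, k = 3, align = "center", fill = NA)

# Trailing: each value averages the point itself and the two points before it
trailing <- rollmean(x, k = 3, align = "right", fill = NA)

rbind(x, centered, trailing)
```

The centered average is undefined at both ends of the series, while the trailing average is undefined only at the start, which is why the trailing version can still produce a value for the most recent period.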
Let’s show an example that applies trailing moving averages to the residuals of a model trained for the original time series data. We forecast future residuals and use these forecasted residuals to improve the forecast based on the model we have adopted.
## Point Forecast Lo 0 Hi 0
## Apr 2001 2004.271 2004.271 2004.271
## Mar
## 2001 30.78068
## Apr
## 2001 2035.052
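The output above appears consistent with adding the most recent trailing moving average of the residuals (30.78068 for March 2001) to the regression forecast for April 2001 (2004.271), which gives 2035.052. A minimal sketch of that procedure follows. It is not the code that produced the output: AirPassengers stands in for the Amtrak series, the window width of 12 is an assumption, and the variable names are illustrative.

```r
library(forecast)
library(zoo)

# Stand-in data (assumption): AirPassengers replaces the Amtrak series
train <- window(AirPassengers, end = c(1958, 12))

# Regression model for the original series
fit <- tslm(train ~ trend + I(trend^2) + season)
base_fc <- forecast(fit, h = 1, level = 0)  # one-step-ahead regression forecast

# Trailing moving average of the residuals (window width is an assumption)
ma_res  <- rollmean(fit$residuals, k = 12, align = "right")
last_ma <- as.numeric(ma_res[length(ma_res)])  # most recent trailing average

# Adjusted forecast = regression forecast + latest trailing average of residuals
adjusted_fc <- as.numeric(base_fc$mean) + last_ma
c(regression = as.numeric(base_fc$mean), residual_ma = last_ma, adjusted = adjusted_fc)
```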
Simple exponential smoothing (SES) is a technique used for making short-term forecasts in time series analysis. It assigns exponentially decreasing weights to past observations, with more recent observations receiving higher weights. Here’s how you can implement simple exponential smoothing in R using the ets() function from the forecast package with model = "ANN" (additive error, no trend, no seasonality); the ses() function from the same package is an alternative interface:
# Fit simple exponential smoothing model to the residuals of another model
ses_model <- ets(train.lm.trend.season$residuals, model = "ANN", alpha = 0.2)
# Generate forecasts
ses_forecast <- forecast(ses_model, h = nTest, level = 0.95)
ses_forecast
## Point Forecast Lo 95 Hi 95
## Apr 2001 14.14285 -100.7591 129.0448
## May 2001 14.14285 -103.0346 131.3203
## Jun 2001 14.14285 -105.2667 133.5524
## Jul 2001 14.14285 -107.4579 135.7436
## Aug 2001 14.14285 -109.6103 137.8960
## Sep 2001 14.14285 -111.7259 140.0116
## Oct 2001 14.14285 -113.8065 142.0922
## Nov 2001 14.14285 -115.8538 144.1395
## Dec 2001 14.14285 -117.8694 146.1551
## Jan 2002 14.14285 -119.8546 148.1404
## Feb 2002 14.14285 -121.8109 150.0966
## Mar 2002 14.14285 -123.7394 152.0251
## Apr 2002 14.14285 -125.6414 153.9271
## May 2002 14.14285 -127.5177 155.8034
## Jun 2002 14.14285 -129.3696 157.6553
## Jul 2002 14.14285 -131.1978 159.4836
## Aug 2002 14.14285 -133.0034 161.2891
## Sep 2002 14.14285 -134.7870 163.0727
## Oct 2002 14.14285 -136.5496 164.8353
## Nov 2002 14.14285 -138.2918 166.5775
## Dec 2002 14.14285 -140.0142 168.2999
## Jan 2003 14.14285 -141.7177 170.0034
## Feb 2003 14.14285 -143.4027 171.6884
## Mar 2003 14.14285 -145.0699 173.3556
## Apr 2003 14.14285 -146.7198 175.0055
## May 2003 14.14285 -148.3530 176.6387
## Jun 2003 14.14285 -149.9699 178.2556
## Jul 2003 14.14285 -151.5710 179.8567
## Aug 2003 14.14285 -153.1568 181.4426
## Sep 2003 14.14285 -154.7278 183.0135
## Oct 2003 14.14285 -156.2842 184.5699
## Nov 2003 14.14285 -157.8266 186.1123
## Dec 2003 14.14285 -159.3552 187.6409
## Jan 2004 14.14285 -160.8705 189.1562
## Feb 2004 14.14285 -162.3728 190.6585
## Mar 2004 14.14285 -163.8624 192.1481
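All forecasted residuals equal 14.14285 (in thousands of riders), because simple exponential smoothing produces a flat forecast at the last estimated level. The 95% prediction intervals differ, however, widening as the forecast horizon increases.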
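The flat forecast follows directly from the SES recursion, in which the level is updated as a weighted average of the newest observation and the previous level. The sketch below implements the recursion by hand on a made-up vector. The function name ses_level and the choice to start from the first observation are assumptions made for illustration; ets() estimates the initial level instead.

```r
# Simple exponential smoothing by hand (illustrative sketch)
ses_level <- function(x, alpha, init = x[1]) {
  level <- init  # assumption: start from the first observation
  for (value in x) {
    # New level = weighted average of the new observation and the old level
    level <- alpha * value + (1 - alpha) * level
  }
  level
}

x <- c(10, 12, 11, 15)                    # toy series (illustrative values)
final_level <- ses_level(x, alpha = 0.2)  # level after the last observation

# Every future forecast equals the final level, regardless of horizon
ses_forecasts <- rep(final_level, 3)
ses_forecasts
```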
We next report the performance of the model:
## ME RMSE MAE MPE MAPE ACF1 Theil's U
## Test set -140.3082 165.0893 143.4225 -7.144826 7.29532 0.7069291 0.966174
We show the Holt-Winters exponential smoothing method, which jointly models the level, trend, and seasonality. For reference, see https://otexts.com/fpp2/holt-winters.html.
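For orientation, the additive Holt-Winters method can be written in component form, following the notation of the reference above. The symbols are not defined elsewhere in this section: \(\ell_t\) is the level, \(b_t\) the trend, \(s_t\) the seasonal component, \(m\) the seasonal period (12 for monthly data), \(\alpha\), \(\beta^*\), and \(\gamma\) the smoothing parameters, and \(k\) the integer part of \((h-1)/m\).

\[
\begin{aligned}
\hat{y}_{t+h|t} &= \ell_t + h b_t + s_{t+h-m(k+1)} \\
\ell_t &= \alpha (y_t - s_{t-m}) + (1-\alpha)(\ell_{t-1} + b_{t-1}) \\
b_t &= \beta^* (\ell_t - \ell_{t-1}) + (1-\beta^*) b_{t-1} \\
s_t &= \gamma (y_t - \ell_{t-1} - b_{t-1}) + (1-\gamma) s_{t-m}
\end{aligned}
\]

These are the recursions for the additive method only; the ets() function fits a state space formulation, and its first model letter additionally specifies whether the error term is additive or multiplicative.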
hwin <- ets(train.ts, model = "MAA") # Multiplicative error, Additive trend, Additive seasonality
hwin.forecast <- forecast(hwin, h = nTest, level = 0.95)
autoplot(train.ts, ylab = "Ridership", color = "blue") +
autolayer(test.ts, color = "blue", linetype = 2) +
autolayer(hwin$fitted, color = "red", size = 0.75) +
autolayer(hwin.forecast$mean, color = "red", size = 0.75)
We next report the performance of the model:
## ME RMSE MAE MPE MAPE ACF1 Theil's U
## Test set 33.66844 76.65107 62.16977 1.581526 3.122042 0.6165553 0.4430438
Association rules are a type of rule-based machine learning technique used for discovering interesting relationships or associations among a set of items in large datasets. These rules are particularly popular in the context of market basket analysis, where the goal is to identify patterns of co-occurrence among products that customers frequently buy together.
Collaborative filtering is a technique used in recommendation systems to predict a user’s interests by collecting preferences from many users (collaborating). The underlying idea is that users who agreed on certain preferences in the past are likely to agree on future preferences. Collaborative filtering methods can be broadly categorized into two main types: user-based collaborative filtering and item-based collaborative filtering.
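The user-based variant can be illustrated with a very small example. The sketch below is one simple way to do it, not a definitive implementation: the rating matrix is made up, cosine similarity over co-rated items is one of several possible similarity measures, and the names ratings, cosine_sim, and pred_C are hypothetical.

```r
# Toy user-item rating matrix (illustrative values); NA marks an unrated item
ratings <- matrix(
  c(5, 4, NA, 1,
    4, 5, 2, NA,
    1, NA, 5, 4),
  nrow = 3, byrow = TRUE,
  dimnames = list(c("u1", "u2", "u3"), c("A", "B", "C", "D"))
)

# Cosine similarity computed over the items both users have rated
cosine_sim <- function(a, b) {
  both <- !is.na(a) & !is.na(b)
  sum(a[both] * b[both]) / (sqrt(sum(a[both]^2)) * sqrt(sum(b[both]^2)))
}

# Similarity of the target user u1 to every other user
others <- c("u2", "u3")
sims <- sapply(others, function(u) cosine_sim(ratings["u1", ], ratings[u, ]))

# Predicted rating of u1 for item C: similarity-weighted average of others' ratings
pred_C <- sum(sims * ratings[others, "C"]) / sum(sims)
list(similarities = sims, predicted_rating_C = pred_C)
```

User u2 rates items A and B much as u1 does, so u2's rating of item C carries more weight in the prediction than u3's.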
Advantages:
Disadvantages and Challenges:
Collaborative filtering requires a user-item interaction matrix (user ratings, likes, purchases) to identify patterns of similarity between users or items. It recommends items to a user based on the preferences of other users who are similar to the target user.
Association rules require transactional data (market basket data) to discover patterns of item co-occurrence. They recommend items based on the co-occurrence of items in historical