## Introduction

Friedman (1991) provided a collection of benchmark datasets for evaluating and comparing regression models. One of the most commonly used examples is synthesized from the following nonlinear equation:

$$y = 10\sin(\pi x_1 x_2) + 20(x_3 - 0.5)^2 + 10x_4 + 5x_5 + N(0, \sigma^2)$$

In this equation, predictors $x_1$ to $x_5$ are informative, whereas $x_6$ to $x_{10}$ are noise variables drawn uniformly from $[0, 1]$. This simulation allows us to assess how effectively models capture the nonlinearity while ignoring the irrelevant predictors.
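To make the data-generating process concrete, here is a minimal hand-rolled sketch of the same simulation (the simulate_friedman1 name is our own; below we use mlbench's built-in generator instead):

# Hand-rolled Friedman #1 simulator: 10 uniform predictors, only the first 5 informative
simulate_friedman1 <- function(n, sd = 1) {
  x <- matrix(runif(n * 10), ncol = 10)
  y <- 10 * sin(pi * x[, 1] * x[, 2]) +
       20 * (x[, 3] - 0.5)^2 +
       10 * x[, 4] + 5 * x[, 5] +
       rnorm(n, sd = sd)
  list(x = x, y = y)
}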
library(mlbench)
library(caret)
## Loading required package: ggplot2
## Loading required package: lattice
library(ggplot2)
library(earth)
## Loading required package: Formula
## Loading required package: plotmo
## Loading required package: plotrix
We begin by generating a training dataset with 200 observations and 10 predictors.
set.seed(200)
trainingData <- mlbench.friedman1(200, sd = 1)
trainingData$x <- data.frame(trainingData$x)
trainingDF <- data.frame(trainingData$x, y = trainingData$y)
This dataset offers a controlled environment for evaluating model performance: the response y follows a known nonlinear relationship, so we can judge each model's accuracy against ground truth.
A much larger test dataset (5,000 observations) is constructed to estimate the models' true generalization error.
testData <- mlbench.friedman1(5000, sd = 1)
testData$x <- data.frame(testData$x)
testDF <- data.frame(testData$x)
We start with a k-nearest neighbors (k-NN) regression model fit through the caret package. The model is trained with bootstrap resampling (25 repetitions), and the predictors are centered and scaled during preprocessing.
set.seed(123)
knnModel <- train(
y ~ .,
data = trainingDF,
method = "knn",
preProcess = c("center", "scale"),
trControl = trainControl(method = "boot", number = 25)
)
caret tunes the number of neighbors k automatically and chooses the value that minimizes the root mean squared error (RMSE).
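To see the resampling profile behind this choice, we can inspect the tuning results directly (a quick check):

knnModel$bestTune   # the chosen k
knnModel$results    # RMSE, R-squared, and MAE across the candidate k values
plot(knnModel)      # RMSE profile over k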
## Evaluating Predictions on the Test Data

Once trained, the model is used to predict on the test set, and we compute the common regression metrics: RMSE, R-squared, and MAE.
knnPred <- predict(knnModel, newdata = testDF)
# Evaluate test-set performance
postResample(pred = knnPred, obs = testData$y)
## RMSE Rsquared MAE
## 3.1172319 0.6556622 2.4899907
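For context, a trivial baseline that always predicts the training-set mean gives a reference point for these metrics (a quick sketch; baselinePred is our own name):

# Baseline: predict the training mean for every test case
baselinePred <- rep(mean(trainingDF$y), length(testData$y))
postResample(pred = baselinePred, obs = testData$y)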
Plotting the k-NN predictions against the true values:
plot(testData$y, knnPred,
xlab = "True Values", ylab = "Predicted Values",
main = "k-NN Predictions vs Truth")
abline(0, 1, col = "gold")
## MARS Implementation

Next, a multivariate adaptive regression splines (MARS) model is fit with the earth package, tuned by 10-fold cross-validation.
marsModel <- train(
y ~ .,
data = trainingDF,
method = "earth",
trControl = trainControl(method = "cv", number = 10)
)
# Predict on the test set
marsPred <- predict(marsModel, newdata = testDF)
# Evaluate
postResample(pred = marsPred, obs = testData$y)
## RMSE Rsquared MAE
## 1.8136467 0.8677298 1.3911836
# Plot variable importance for MARS
plot(varImp(marsModel), top = 10,
main = "Variable Importance (MARS)")
plot(testData$y, marsPred,
xlab = "True Values", ylab = "Predicted Values",
main = "MARS Predictions vs Truth",
col = "skyblue", pch = 16)
abline(0, 1, col = "red", lwd = 2)
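Beyond accuracy, it is worth checking which hinge functions MARS actually selected; on the Friedman simulation we would expect terms involving only x1 through x5. A quick look at the fitted earth object:

summary(marsModel$finalModel)  # pruned basis functions and their coefficients
marsModel$bestTune             # selected nprune and degree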
The following residual plot helps diagnose the k-NN model's fit: ideally the residuals are centered around 0 with no strong pattern.
residuals <- testData$y - knnPred
plot(knnPred, residuals,
xlab = "Predicted Values",
ylab = "Residuals",
main = "Residuals vs Predicted",
col = "lightgreen", pch = 16)
abline(h = 0, col = "black", lty = 2)
## Density Plot of Residuals
ggplot(data.frame(residuals), aes(x = residuals)) +
geom_density(fill = "beige", alpha = 0.6) +
geom_vline(xintercept = 0, linetype = "dashed", color = "orange") +
labs(title = "Density Plot of Residuals")
## Exploring Predictor-Response Structure

The MARS model and its variable importance were already shown above, so here we examine the training data directly. Binning the response into three levels lets featurePlot display how the first two (informative) predictors shift across low, medium, and high values of y:
trainingDF$y_bin <- cut(trainingDF$y, breaks = 3)
featurePlot(x = trainingDF[, 1:2], y = trainingDF$y_bin,
plot = "box")
modelCompare <- data.frame(
  True = testData$y,
  kNN  = knnPred,
  MARS = as.vector(marsPred)  # earth returns a one-column matrix
)
ggplot(modelCompare, aes(x = kNN, y = MARS)) +
  geom_point(alpha = 0.4, color = "lightblue") +
  geom_abline(slope = 1, intercept = 0, color = "tomato", linetype = "dashed") +
  labs(title = "Comparison: k-NN vs MARS Predictions",
       x = "k-NN Prediction", y = "MARS Prediction")
**Even though both models showed respectable predictive accuracy on the test data (k-NN: RMSE 3.12, R-squared 0.66; MARS: RMSE 1.81, R-squared 0.87), it is not recommended to completely replace a laboratory experiment such as the permeability assay with a model trained only on simulated data.
These models were trained on data with a known structure and little noise, whereas real-world settings contain more intricate and unidentified sources of variance. However, such models can serve as efficient screening tools: by highlighting important factors or predicting outcomes in advance, they can sharpen experimental design and potentially reduce the number of trials required.
In conclusion, models such as MARS can supplement empirical lab testing in applications like permeability analysis, but they cannot completely replace it.**
This exercise uses a dataset that mimics a chemical manufacturing process. After performing data imputation, splitting, and pre-processing as described in the exercise, we fit several nonlinear regression models to determine which performs best.
## (a) Which nonlinear regression model gives the optimal resampling and test set performance?
Three models (k-NN, MARS, and random forest) will be trained and assessed. After training, resampling is used to estimate performance, and the results are verified on the held-out test set.
library(randomForest)
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
library(AppliedPredictiveModeling)
library(RANN)
set.seed(200)
data("ChemicalManufacturingProcess")
summary(ChemicalManufacturingProcess)
## Yield BiologicalMaterial01 BiologicalMaterial02 BiologicalMaterial03
## Min. :35.25 Min. :4.580 Min. :46.87 Min. :56.97
## 1st Qu.:38.75 1st Qu.:5.978 1st Qu.:52.68 1st Qu.:64.98
## Median :39.97 Median :6.305 Median :55.09 Median :67.22
## Mean :40.18 Mean :6.411 Mean :55.69 Mean :67.70
## 3rd Qu.:41.48 3rd Qu.:6.870 3rd Qu.:58.74 3rd Qu.:70.43
## Max. :46.34 Max. :8.810 Max. :64.75 Max. :78.25
##
## BiologicalMaterial04 BiologicalMaterial05 BiologicalMaterial06
## Min. : 9.38 Min. :13.24 Min. :40.60
## 1st Qu.:11.24 1st Qu.:17.23 1st Qu.:46.05
## Median :12.10 Median :18.49 Median :48.46
## Mean :12.35 Mean :18.60 Mean :48.91
## 3rd Qu.:13.22 3rd Qu.:19.90 3rd Qu.:51.34
## Max. :23.09 Max. :24.85 Max. :59.38
##
## BiologicalMaterial07 BiologicalMaterial08 BiologicalMaterial09
## Min. :100.0 Min. :15.88 Min. :11.44
## 1st Qu.:100.0 1st Qu.:17.06 1st Qu.:12.60
## Median :100.0 Median :17.51 Median :12.84
## Mean :100.0 Mean :17.49 Mean :12.85
## 3rd Qu.:100.0 3rd Qu.:17.88 3rd Qu.:13.13
## Max. :100.8 Max. :19.14 Max. :14.08
##
## BiologicalMaterial10 BiologicalMaterial11 BiologicalMaterial12
## Min. :1.770 Min. :135.8 Min. :18.35
## 1st Qu.:2.460 1st Qu.:143.8 1st Qu.:19.73
## Median :2.710 Median :146.1 Median :20.12
## Mean :2.801 Mean :147.0 Mean :20.20
## 3rd Qu.:2.990 3rd Qu.:149.6 3rd Qu.:20.75
## Max. :6.870 Max. :158.7 Max. :22.21
##
## ManufacturingProcess01 ManufacturingProcess02 ManufacturingProcess03
## Min. : 0.00 Min. : 0.00 Min. :1.47
## 1st Qu.:10.80 1st Qu.:19.30 1st Qu.:1.53
## Median :11.40 Median :21.00 Median :1.54
## Mean :11.21 Mean :16.68 Mean :1.54
## 3rd Qu.:12.15 3rd Qu.:21.50 3rd Qu.:1.55
## Max. :14.10 Max. :22.50 Max. :1.60
## NA's :1 NA's :3 NA's :15
## ManufacturingProcess04 ManufacturingProcess05 ManufacturingProcess06
## Min. :911.0 Min. : 923.0 Min. :203.0
## 1st Qu.:928.0 1st Qu.: 986.8 1st Qu.:205.7
## Median :934.0 Median : 999.2 Median :206.8
## Mean :931.9 Mean :1001.7 Mean :207.4
## 3rd Qu.:936.0 3rd Qu.:1008.9 3rd Qu.:208.7
## Max. :946.0 Max. :1175.3 Max. :227.4
## NA's :1 NA's :1 NA's :2
## ManufacturingProcess07 ManufacturingProcess08 ManufacturingProcess09
## Min. :177.0 Min. :177.0 Min. :38.89
## 1st Qu.:177.0 1st Qu.:177.0 1st Qu.:44.89
## Median :177.0 Median :178.0 Median :45.73
## Mean :177.5 Mean :177.6 Mean :45.66
## 3rd Qu.:178.0 3rd Qu.:178.0 3rd Qu.:46.52
## Max. :178.0 Max. :178.0 Max. :49.36
## NA's :1 NA's :1
## ManufacturingProcess10 ManufacturingProcess11 ManufacturingProcess12
## Min. : 7.500 Min. : 7.500 Min. : 0.0
## 1st Qu.: 8.700 1st Qu.: 9.000 1st Qu.: 0.0
## Median : 9.100 Median : 9.400 Median : 0.0
## Mean : 9.179 Mean : 9.386 Mean : 857.8
## 3rd Qu.: 9.550 3rd Qu.: 9.900 3rd Qu.: 0.0
## Max. :11.600 Max. :11.500 Max. :4549.0
## NA's :9 NA's :10 NA's :1
## ManufacturingProcess13 ManufacturingProcess14 ManufacturingProcess15
## Min. :32.10 Min. :4701 Min. :5904
## 1st Qu.:33.90 1st Qu.:4828 1st Qu.:6010
## Median :34.60 Median :4856 Median :6032
## Mean :34.51 Mean :4854 Mean :6039
## 3rd Qu.:35.20 3rd Qu.:4882 3rd Qu.:6061
## Max. :38.60 Max. :5055 Max. :6233
## NA's :1
## ManufacturingProcess16 ManufacturingProcess17 ManufacturingProcess18
## Min. : 0 Min. :31.30 Min. : 0
## 1st Qu.:4561 1st Qu.:33.50 1st Qu.:4813
## Median :4588 Median :34.40 Median :4835
## Mean :4566 Mean :34.34 Mean :4810
## 3rd Qu.:4619 3rd Qu.:35.10 3rd Qu.:4862
## Max. :4852 Max. :40.00 Max. :4971
##
## ManufacturingProcess19 ManufacturingProcess20 ManufacturingProcess21
## Min. :5890 Min. : 0 Min. :-1.8000
## 1st Qu.:6001 1st Qu.:4553 1st Qu.:-0.6000
## Median :6022 Median :4582 Median :-0.3000
## Mean :6028 Mean :4556 Mean :-0.1642
## 3rd Qu.:6050 3rd Qu.:4610 3rd Qu.: 0.0000
## Max. :6146 Max. :4759 Max. : 3.6000
##
## ManufacturingProcess22 ManufacturingProcess23 ManufacturingProcess24
## Min. : 0.000 Min. :0.000 Min. : 0.000
## 1st Qu.: 3.000 1st Qu.:2.000 1st Qu.: 4.000
## Median : 5.000 Median :3.000 Median : 8.000
## Mean : 5.406 Mean :3.017 Mean : 8.834
## 3rd Qu.: 8.000 3rd Qu.:4.000 3rd Qu.:14.000
## Max. :12.000 Max. :6.000 Max. :23.000
## NA's :1 NA's :1 NA's :1
## ManufacturingProcess25 ManufacturingProcess26 ManufacturingProcess27
## Min. : 0 Min. : 0 Min. : 0
## 1st Qu.:4832 1st Qu.:6020 1st Qu.:4560
## Median :4855 Median :6047 Median :4587
## Mean :4828 Mean :6016 Mean :4563
## 3rd Qu.:4877 3rd Qu.:6070 3rd Qu.:4609
## Max. :4990 Max. :6161 Max. :4710
## NA's :5 NA's :5 NA's :5
## ManufacturingProcess28 ManufacturingProcess29 ManufacturingProcess30
## Min. : 0.000 Min. : 0.00 Min. : 0.000
## 1st Qu.: 0.000 1st Qu.:19.70 1st Qu.: 8.800
## Median :10.400 Median :19.90 Median : 9.100
## Mean : 6.592 Mean :20.01 Mean : 9.161
## 3rd Qu.:10.750 3rd Qu.:20.40 3rd Qu.: 9.700
## Max. :11.500 Max. :22.00 Max. :11.200
## NA's :5 NA's :5 NA's :5
## ManufacturingProcess31 ManufacturingProcess32 ManufacturingProcess33
## Min. : 0.00 Min. :143.0 Min. :56.00
## 1st Qu.:70.10 1st Qu.:155.0 1st Qu.:62.00
## Median :70.80 Median :158.0 Median :64.00
## Mean :70.18 Mean :158.5 Mean :63.54
## 3rd Qu.:71.40 3rd Qu.:162.0 3rd Qu.:65.00
## Max. :72.50 Max. :173.0 Max. :70.00
## NA's :5 NA's :5
## ManufacturingProcess34 ManufacturingProcess35 ManufacturingProcess36
## Min. :2.300 Min. :463.0 Min. :0.01700
## 1st Qu.:2.500 1st Qu.:490.0 1st Qu.:0.01900
## Median :2.500 Median :495.0 Median :0.02000
## Mean :2.494 Mean :495.6 Mean :0.01957
## 3rd Qu.:2.500 3rd Qu.:501.5 3rd Qu.:0.02000
## Max. :2.600 Max. :522.0 Max. :0.02200
## NA's :5 NA's :5 NA's :5
## ManufacturingProcess37 ManufacturingProcess38 ManufacturingProcess39
## Min. :0.000 Min. :0.000 Min. :0.000
## 1st Qu.:0.700 1st Qu.:2.000 1st Qu.:7.100
## Median :1.000 Median :3.000 Median :7.200
## Mean :1.014 Mean :2.534 Mean :6.851
## 3rd Qu.:1.300 3rd Qu.:3.000 3rd Qu.:7.300
## Max. :2.300 Max. :3.000 Max. :7.500
##
## ManufacturingProcess40 ManufacturingProcess41 ManufacturingProcess42
## Min. :0.00000 Min. :0.00000 Min. : 0.00
## 1st Qu.:0.00000 1st Qu.:0.00000 1st Qu.:11.40
## Median :0.00000 Median :0.00000 Median :11.60
## Mean :0.01771 Mean :0.02371 Mean :11.21
## 3rd Qu.:0.00000 3rd Qu.:0.00000 3rd Qu.:11.70
## Max. :0.10000 Max. :0.20000 Max. :12.10
## NA's :1 NA's :1
## ManufacturingProcess43 ManufacturingProcess44 ManufacturingProcess45
## Min. : 0.0000 Min. :0.000 Min. :0.000
## 1st Qu.: 0.6000 1st Qu.:1.800 1st Qu.:2.100
## Median : 0.8000 Median :1.900 Median :2.200
## Mean : 0.9119 Mean :1.805 Mean :2.138
## 3rd Qu.: 1.0250 3rd Qu.:1.900 3rd Qu.:2.300
## Max. :11.0000 Max. :2.100 Max. :2.600
##
# Separate predictors and outcome
predictors <- ChemicalManufacturingProcess[, -1]
outcome <- ChemicalManufacturingProcess$Yield
# Imputation and scaling
preProc <- preProcess(predictors, method = c("knnImpute", "center", "scale"))
predictors_processed <- predict(preProc, newdata = predictors)
# Combine outcome with processed predictors
finalData <- data.frame(predictors_processed, Yield = outcome)
set.seed(123)
trainIndex <- createDataPartition(finalData$Yield, p = 0.8, list = FALSE)
trainData <- finalData[trainIndex, ]
testData <- finalData[-trainIndex, ]
preProcValues <- preProcess(trainData, method = c("center", "scale"))
trainData <- predict(preProcValues, trainData)
testData <- predict(preProcValues, testData)
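As a quick sanity check that knnImpute removed all missing values before modeling:

# Confirm no NAs remain in either split
stopifnot(sum(is.na(trainData)) == 0, sum(is.na(testData)) == 0)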
Now we train three models: k-NN, MARS, and random forest, each tuned with 10-fold cross-validation. Note that the second preprocessing step above also centered and scaled Yield, so the RMSE and MAE values reported below are on the standardized scale.
#a k-NN models
set.seed(123)
knnModel <- train(
Yield ~ ., data = trainData,
method = "knn",
trControl = trainControl(method = "cv", number = 10)
)
#b MARS Model
set.seed(123)
marsModel <- train(
Yield ~ ., data = trainData,
method = "earth",
trControl = trainControl(method = "cv", number = 10)
)
#c RandomForest model
set.seed(123)
rfModel <- train(
Yield ~ ., data = trainData,
method = "rf",
trControl = trainControl(method = "cv", number = 10)
)
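To compare the cross-validated resampling performance of all three models side by side, caret's resamples() helper is convenient (a sketch; the model labels are our own):

resamps <- resamples(list(kNN = knnModel, MARS = marsModel, RF = rfModel))
summary(resamps)                  # CV RMSE, R-squared, and MAE for each model
bwplot(resamps, metric = "RMSE")  # box plots of the resampled RMSE values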
Next we predict on the test data and evaluate each model with the usual metrics: RMSE, R-squared, and MAE.
#a k-NN predictions and evaluation
knnPred <- predict(knnModel, newdata = testData)
postResample(pred = knnPred, obs = testData$Yield)
## RMSE Rsquared MAE
## 0.7560399 0.4284176 0.6249868
#b MARS predictions and evaluation
marsPred <- predict(marsModel, newdata = testData)
postResample(pred = marsPred, obs = testData$Yield)
## RMSE Rsquared MAE
## 0.7063439 0.5266952 0.5992609
#c Random Forest predictions and evaluation
rfPred <- predict(rfModel, newdata = testData)
postResample(pred = rfPred, obs = testData$Yield)
## RMSE Rsquared MAE
## 0.6879460 0.5309137 0.5155775
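Collecting the three sets of test-set metrics into one table makes the comparison easier to read:

# Side-by-side test-set metrics for all three models
rbind(
  kNN  = postResample(knnPred, testData$Yield),
  MARS = postResample(as.vector(marsPred), testData$Yield),
  RF   = postResample(rfPred, testData$Yield)
)

Random forest achieves the lowest RMSE and MAE, with MARS close behind, so we examine variable importance for both in part (b).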
## (b) Which predictors are most important in the optimal nonlinear regression model?
We extract variable importance from the two strongest models, MARS and random forest:
# Variable importance for MARS
plot(varImp(marsModel), main = "Variable Importance for MARS")
plot(varImp(rfModel), main = "Variable Importance for Random Forest")
Let's see how the top predictors relate to the response (Yield). We'll use scatterplots with loess smoothers, plus a partial dependence sketch at the end of the section.
# Get top 5 predictors from Random Forest
rf_importance <- varImp(rfModel)$importance
rf_sorted <- rf_importance[order(rf_importance$Overall, decreasing = TRUE), , drop = FALSE]
rf_top <- rownames(rf_sorted)[1:5]
# Plot predictor vs Yield
# Plot each top predictor vs Yield (tidy evaluation instead of deprecated aes_string)
for (var in rf_top) {
  p <- ggplot(trainData, aes(x = .data[[var]], y = Yield)) +
    geom_point(alpha = 0.4, color = "darkgreen") +
    geom_smooth(method = "loess", se = TRUE, color = "darkorange") +
    labs(title = paste("Yield vs", var),
         x = var,
         y = "Yield") +
    theme_minimal()
  print(p)
}
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'
# Scatter plots of the same top 5 predictors vs Yield (reusing rf_top from above)
featurePlot(x = trainData[, rf_top],
y = trainData$Yield,
plot = "scatter",
auto.key = list(columns = 2))
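Finally, the loess scatterplots above only approximate partial dependence; the randomForest package also provides partialPlot() for a model-based view. A minimal sketch using ManufacturingProcess32, which typically ranks at or near the top for this dataset (substitute the leading name from rf_top in your run):

# Model-based partial dependence of (scaled) Yield on one top process variable
partialPlot(rfModel$finalModel,
            pred.data = trainData,
            x.var = "ManufacturingProcess32")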