Nonlinear Regression Simulation: Friedman’s Benchmark Problem 7.2 & 7.5

Getting Started

Introduction

Friedman (1991) provided a collection of benchmark datasets for evaluating and comparing regression models. One of the most commonly used examples is synthesized from the following nonlinear equation:

$$y = 10 \sin(\pi x_1 x_2) + 20 (x_3 - 0.5)^2 + 10 x_4 + 5 x_5 + N(0, \sigma^2)$$

In this equation, predictors $x_1$ to $x_5$ are informative, whereas $x_6$ to $x_{10}$ are uninformative noise variables. All ten predictors are drawn uniformly from $[0, 1]$. This simulation lets us assess how well models capture nonlinearity while ignoring irrelevant predictors.
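
As a rough illustration of the equation above, the response can also be simulated by hand in base R. This is only a sketch under stated assumptions, not the mlbench implementation: the helper name simulate_friedman1 and its arguments are invented for this example.

```r
# Hypothetical helper (not part of mlbench): simulate Friedman's benchmark by hand.
simulate_friedman1 <- function(n, sd = 1, p = 10) {
  stopifnot(p >= 5)
  # All p predictors are drawn independently from Uniform(0, 1)
  x <- matrix(runif(n * p), nrow = n, ncol = p,
              dimnames = list(NULL, paste0("X", seq_len(p))))
  # Noise-free signal: only the first five columns enter the formula
  signal <- 10 * sin(pi * x[, 1] * x[, 2]) +
    20 * (x[, 3] - 0.5)^2 +
    10 * x[, 4] +
    5 * x[, 5]
  # Add Gaussian noise with standard deviation sd
  y <- signal + rnorm(n, mean = 0, sd = sd)
  list(x = x, y = y, signal = signal)
}

set.seed(1)
sim <- simulate_friedman1(200, sd = 1)
dim(sim$x)     # number of rows and columns
range(sim$x)   # predictor values stay inside [0, 1]
```

Columns 6 to 10 never appear in the signal, which is why a good model should assign them little or no importance.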

Load Packages

library(mlbench)
library(caret)
## Loading required package: ggplot2
## Loading required package: lattice
library(ggplot2)
library(earth)
## Loading required package: Formula
## Loading required package: plotmo
## Loading required package: plotrix

Simulate Training Data

We start by generating a training dataset with 200 observations and 10 predictors.

set.seed(200)
trainingData <- mlbench.friedman1(200, sd = 1)
trainingData$x <- data.frame(trainingData$x)
trainingDF <- data.frame(trainingData$x, y = trainingData$y)

This dataset offers a controlled environment for evaluating model performance. The response y follows a known nonlinear relationship, so model accuracy can be assessed against a known ground truth.

Simulate Test Data

A larger test dataset (5,000 observations) is generated to estimate the models’ true generalization error.

testData <- mlbench.friedman1(5000, sd = 1)
testData$x <- data.frame(testData$x)
testDF <- data.frame(testData$x)

Evaluation & Model Tuning

We start by fitting a k-nearest neighbors (k-NN) regression model with the caret package. The model is trained using bootstrap resampling (25 repetitions), and the predictors are centered and scaled during preprocessing.

set.seed(123) 
knnModel <- train(
  y ~ .,
  data = trainingDF,
  method = "knn",
  preProcess = c("center", "scale"),
  trControl = trainControl(method = "boot", number = 25)
)

caret tunes the number of neighbors k automatically and selects the value that minimizes the resampled root mean squared error (RMSE).
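
To make the idea of choosing k by RMSE concrete, here is a minimal base R sketch. It swaps in a single holdout split for caret's bootstrap resampling, and the toy data, the helper names knn_predict and rmse, and the grid of k values are all assumptions made for illustration.

```r
# Plain k-NN regression: average the responses of the k closest training rows
knn_predict <- function(train_x, train_y, new_x, k) {
  apply(new_x, 1, function(row) {
    # Euclidean distance from this row to every training row
    d <- sqrt(colSums((t(train_x) - row)^2))
    mean(train_y[order(d)[seq_len(k)]])
  })
}

rmse <- function(pred, obs) sqrt(mean((pred - obs)^2))

# Toy data (assumption): a small nonlinear signal plus noise
set.seed(42)
n <- 200
x <- matrix(runif(n * 3), ncol = 3)
y <- 10 * sin(pi * x[, 1] * x[, 2]) + 5 * x[, 3] + rnorm(n)

# Single holdout split; bootstrap resampling would repeat this idea many times
holdout <- sample(n, 50)
x_fit <- x[-holdout, ]; y_fit <- y[-holdout]
x_val <- x[holdout, ];  y_val <- y[holdout]

# Center and scale using statistics from the fitting rows only
mu <- colMeans(x_fit)
s  <- apply(x_fit, 2, sd)
x_fit_s <- scale(x_fit, center = mu, scale = s)
x_val_s <- scale(x_val, center = mu, scale = s)

# Evaluate each candidate k and keep the one with the lowest holdout RMSE
k_grid <- c(1, 3, 5, 7, 9)
val_rmse <- sapply(k_grid, function(k) {
  rmse(knn_predict(x_fit_s, y_fit, x_val_s, k), y_val)
})
best_k <- k_grid[which.min(val_rmse)]
data.frame(k = k_grid, RMSE = val_rmse)
best_k
```

Scaling matters here because k-NN relies on distances, so a predictor on a larger scale would otherwise dominate the neighbor search.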

Evaluate Predictions on Test Data

Once trained, the model is used to predict on the test set, and we compute the common regression metrics: RMSE, R-squared, and MAE.

knnPred <- predict(knnModel, newdata = testDF)

# Evaluate test-set performance

postResample(pred = knnPred, obs = testData$y)
##      RMSE  Rsquared       MAE 
## 3.1172319 0.6556622 2.4899907
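
The three metrics above can be reproduced by hand, which helps when interpreting them. The sketch below assumes the definitions that caret documents for postResample (R-squared as the squared correlation between predictions and observations); the helper name regression_metrics and the small vectors are made up for illustration.

```r
# Hypothetical helper mirroring the three metrics reported above
regression_metrics <- function(pred, obs) {
  c(RMSE     = sqrt(mean((pred - obs)^2)),   # root mean squared error
    Rsquared = cor(pred, obs)^2,             # squared Pearson correlation
    MAE      = mean(abs(pred - obs)))        # mean absolute error
}

# Tiny made-up example
obs  <- c(3, 5, 7, 9)
pred <- c(2, 5, 8, 9)
regression_metrics(pred, obs)
```

Because R-squared is computed here as a squared correlation, it measures how well predictions track the observations linearly and does not penalize a systematic offset, so it is worth reading alongside RMSE and MAE.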

Visualization

Plotting Prediction vs True Values

plot(testData$y, knnPred, 
     xlab = "True Values", ylab = "Predicted Values", 
     main = "k-NN Predictions vs Truth")
abline(0, 1, col = "gold")

## MARS Implementation

marsModel <- train(
  y ~ .,
  data = trainingDF,
  method = "earth",
  trControl = trainControl(method = "cv", number = 10)
)

# Predict on the test set
marsPred <- predict(marsModel, newdata = testDF)

# Evaluate
postResample(pred = marsPred, obs = testData$y)
##      RMSE  Rsquared       MAE 
## 1.8136467 0.8677298 1.3911836
# Plot variable importance for MARS
plot(varImp(marsModel), top = 10, 
     main = "Variable Importance (MARS)")

plot(testData$y, marsPred,
     xlab = "True Values", ylab = "Predicted Values",
     main = "MARS Predictions vs Truth",
     col = "skyblue", pch = 16)
abline(0, 1, col = "red", lwd = 2)

Residual Plot

The following plot helps diagnose the fit of the k-NN model. The residuals are centered around 0 with no strong pattern.

residuals <- testData$y - knnPred

plot(knnPred, residuals,
     xlab = "Predicted Values",
     ylab = "Residuals",
     main = "Residuals vs Predicted",
     col = "lightgreen", pch = 16)
abline(h = 0, col = "black", lty = 2)

## Density Plot of Residuals

ggplot(data.frame(residuals), aes(x = residuals)) +
  geom_density(fill = "beige", alpha = 0.6) +
  geom_vline(xintercept = 0, linetype = "dashed", color = "orange") +
  labs(title = "Density Plot of Residuals")

## Incorporating the MARS Model

marsModel <- train(
  y ~ .,
  data = trainingDF,
  method = "earth",
  trControl = trainControl(method = "cv", number = 10)
)

plot(varImp(marsModel), top = 10)

Final Visual using Boxplot of Predictors vs Binned Response

trainingDF$y_bin <- cut(trainingDF$y, breaks = 3)

featurePlot(x = trainingDF[, 1:2], y = trainingDF$y_bin, 
            plot = "box")

Compare K-NN vs MARS

modelCompare <- data.frame(
  True = testData$y,
  kNN = knnPred,
  MARS = marsPred
)

ggplot(modelCompare, aes(x = kNN, y = MARS)) +
  geom_point(alpha = 0.4, color = "lightblue") +
  geom_abline(slope = 1, intercept = 0, color = "tomato", linetype = "dashed") +
  labs(title = "Comparison: k-NN vs MARS Predictions",
       x = "k-NN Prediction", y = "MARS Prediction")

Would you recommend any of the models you have developed to replace the permeability laboratory experiment?

It is not recommended to completely replace the permeability laboratory experiment with a model based only on simulated data, even though both models showed respectable prediction accuracy on the test data: k-NN reached an RMSE of about 3.12 with an R-squared of about 0.66, and MARS reached an RMSE of about 1.81 with an R-squared of about 0.87.

These models were trained on data with a known structure and little noise. Real-world settings involve more intricate and unidentified sources of variance. However, by focusing on important factors or predicting results beforehand, these models may serve as efficient screening tools or help improve experimental design, potentially lowering the number of trials required.

In conclusion, models such as MARS can supplement empirical lab testing in permeability analysis, but they cannot completely replace it.

Nonlinear Regression: Chemical Manufacturing Process Data

A dataset that mimics a chemical production process is being used in this exercise. To determine which works best, we will use different nonlinear regression models after performing data imputation, splitting, and pre-processing as described in the exercise.

(a) Which nonlinear regression model gives the optimal resampling and test set performance?

Several models, including MARS, k-NN (again), and random forest, will be trained and assessed. After training, we use resampling to evaluate performance and verify the results on the test set.

Load Additional Required Packages

library(randomForest)
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
library(AppliedPredictiveModeling)
library(RANN)

Processing & Splitting data

set.seed(200)
data("ChemicalManufacturingProcess") 

Explore Missing Data

summary(ChemicalManufacturingProcess)
##      Yield       BiologicalMaterial01 BiologicalMaterial02 BiologicalMaterial03
##  Min.   :35.25   Min.   :4.580        Min.   :46.87        Min.   :56.97       
##  1st Qu.:38.75   1st Qu.:5.978        1st Qu.:52.68        1st Qu.:64.98       
##  Median :39.97   Median :6.305        Median :55.09        Median :67.22       
##  Mean   :40.18   Mean   :6.411        Mean   :55.69        Mean   :67.70       
##  3rd Qu.:41.48   3rd Qu.:6.870        3rd Qu.:58.74        3rd Qu.:70.43       
##  Max.   :46.34   Max.   :8.810        Max.   :64.75        Max.   :78.25       
##                                                                                
##  BiologicalMaterial04 BiologicalMaterial05 BiologicalMaterial06
##  Min.   : 9.38        Min.   :13.24        Min.   :40.60       
##  1st Qu.:11.24        1st Qu.:17.23        1st Qu.:46.05       
##  Median :12.10        Median :18.49        Median :48.46       
##  Mean   :12.35        Mean   :18.60        Mean   :48.91       
##  3rd Qu.:13.22        3rd Qu.:19.90        3rd Qu.:51.34       
##  Max.   :23.09        Max.   :24.85        Max.   :59.38       
##                                                                
##  BiologicalMaterial07 BiologicalMaterial08 BiologicalMaterial09
##  Min.   :100.0        Min.   :15.88        Min.   :11.44       
##  1st Qu.:100.0        1st Qu.:17.06        1st Qu.:12.60       
##  Median :100.0        Median :17.51        Median :12.84       
##  Mean   :100.0        Mean   :17.49        Mean   :12.85       
##  3rd Qu.:100.0        3rd Qu.:17.88        3rd Qu.:13.13       
##  Max.   :100.8        Max.   :19.14        Max.   :14.08       
##                                                                
##  BiologicalMaterial10 BiologicalMaterial11 BiologicalMaterial12
##  Min.   :1.770        Min.   :135.8        Min.   :18.35       
##  1st Qu.:2.460        1st Qu.:143.8        1st Qu.:19.73       
##  Median :2.710        Median :146.1        Median :20.12       
##  Mean   :2.801        Mean   :147.0        Mean   :20.20       
##  3rd Qu.:2.990        3rd Qu.:149.6        3rd Qu.:20.75       
##  Max.   :6.870        Max.   :158.7        Max.   :22.21       
##                                                                
##  ManufacturingProcess01 ManufacturingProcess02 ManufacturingProcess03
##  Min.   : 0.00          Min.   : 0.00          Min.   :1.47          
##  1st Qu.:10.80          1st Qu.:19.30          1st Qu.:1.53          
##  Median :11.40          Median :21.00          Median :1.54          
##  Mean   :11.21          Mean   :16.68          Mean   :1.54          
##  3rd Qu.:12.15          3rd Qu.:21.50          3rd Qu.:1.55          
##  Max.   :14.10          Max.   :22.50          Max.   :1.60          
##  NA's   :1              NA's   :3              NA's   :15            
##  ManufacturingProcess04 ManufacturingProcess05 ManufacturingProcess06
##  Min.   :911.0          Min.   : 923.0         Min.   :203.0         
##  1st Qu.:928.0          1st Qu.: 986.8         1st Qu.:205.7         
##  Median :934.0          Median : 999.2         Median :206.8         
##  Mean   :931.9          Mean   :1001.7         Mean   :207.4         
##  3rd Qu.:936.0          3rd Qu.:1008.9         3rd Qu.:208.7         
##  Max.   :946.0          Max.   :1175.3         Max.   :227.4         
##  NA's   :1              NA's   :1              NA's   :2             
##  ManufacturingProcess07 ManufacturingProcess08 ManufacturingProcess09
##  Min.   :177.0          Min.   :177.0          Min.   :38.89         
##  1st Qu.:177.0          1st Qu.:177.0          1st Qu.:44.89         
##  Median :177.0          Median :178.0          Median :45.73         
##  Mean   :177.5          Mean   :177.6          Mean   :45.66         
##  3rd Qu.:178.0          3rd Qu.:178.0          3rd Qu.:46.52         
##  Max.   :178.0          Max.   :178.0          Max.   :49.36         
##  NA's   :1              NA's   :1                                    
##  ManufacturingProcess10 ManufacturingProcess11 ManufacturingProcess12
##  Min.   : 7.500         Min.   : 7.500         Min.   :   0.0        
##  1st Qu.: 8.700         1st Qu.: 9.000         1st Qu.:   0.0        
##  Median : 9.100         Median : 9.400         Median :   0.0        
##  Mean   : 9.179         Mean   : 9.386         Mean   : 857.8        
##  3rd Qu.: 9.550         3rd Qu.: 9.900         3rd Qu.:   0.0        
##  Max.   :11.600         Max.   :11.500         Max.   :4549.0        
##  NA's   :9              NA's   :10             NA's   :1             
##  ManufacturingProcess13 ManufacturingProcess14 ManufacturingProcess15
##  Min.   :32.10          Min.   :4701           Min.   :5904          
##  1st Qu.:33.90          1st Qu.:4828           1st Qu.:6010          
##  Median :34.60          Median :4856           Median :6032          
##  Mean   :34.51          Mean   :4854           Mean   :6039          
##  3rd Qu.:35.20          3rd Qu.:4882           3rd Qu.:6061          
##  Max.   :38.60          Max.   :5055           Max.   :6233          
##                         NA's   :1                                    
##  ManufacturingProcess16 ManufacturingProcess17 ManufacturingProcess18
##  Min.   :   0           Min.   :31.30          Min.   :   0          
##  1st Qu.:4561           1st Qu.:33.50          1st Qu.:4813          
##  Median :4588           Median :34.40          Median :4835          
##  Mean   :4566           Mean   :34.34          Mean   :4810          
##  3rd Qu.:4619           3rd Qu.:35.10          3rd Qu.:4862          
##  Max.   :4852           Max.   :40.00          Max.   :4971          
##                                                                      
##  ManufacturingProcess19 ManufacturingProcess20 ManufacturingProcess21
##  Min.   :5890           Min.   :   0           Min.   :-1.8000       
##  1st Qu.:6001           1st Qu.:4553           1st Qu.:-0.6000       
##  Median :6022           Median :4582           Median :-0.3000       
##  Mean   :6028           Mean   :4556           Mean   :-0.1642       
##  3rd Qu.:6050           3rd Qu.:4610           3rd Qu.: 0.0000       
##  Max.   :6146           Max.   :4759           Max.   : 3.6000       
##                                                                      
##  ManufacturingProcess22 ManufacturingProcess23 ManufacturingProcess24
##  Min.   : 0.000         Min.   :0.000          Min.   : 0.000        
##  1st Qu.: 3.000         1st Qu.:2.000          1st Qu.: 4.000        
##  Median : 5.000         Median :3.000          Median : 8.000        
##  Mean   : 5.406         Mean   :3.017          Mean   : 8.834        
##  3rd Qu.: 8.000         3rd Qu.:4.000          3rd Qu.:14.000        
##  Max.   :12.000         Max.   :6.000          Max.   :23.000        
##  NA's   :1              NA's   :1              NA's   :1             
##  ManufacturingProcess25 ManufacturingProcess26 ManufacturingProcess27
##  Min.   :   0           Min.   :   0           Min.   :   0          
##  1st Qu.:4832           1st Qu.:6020           1st Qu.:4560          
##  Median :4855           Median :6047           Median :4587          
##  Mean   :4828           Mean   :6016           Mean   :4563          
##  3rd Qu.:4877           3rd Qu.:6070           3rd Qu.:4609          
##  Max.   :4990           Max.   :6161           Max.   :4710          
##  NA's   :5              NA's   :5              NA's   :5             
##  ManufacturingProcess28 ManufacturingProcess29 ManufacturingProcess30
##  Min.   : 0.000         Min.   : 0.00          Min.   : 0.000        
##  1st Qu.: 0.000         1st Qu.:19.70          1st Qu.: 8.800        
##  Median :10.400         Median :19.90          Median : 9.100        
##  Mean   : 6.592         Mean   :20.01          Mean   : 9.161        
##  3rd Qu.:10.750         3rd Qu.:20.40          3rd Qu.: 9.700        
##  Max.   :11.500         Max.   :22.00          Max.   :11.200        
##  NA's   :5              NA's   :5              NA's   :5             
##  ManufacturingProcess31 ManufacturingProcess32 ManufacturingProcess33
##  Min.   : 0.00          Min.   :143.0          Min.   :56.00         
##  1st Qu.:70.10          1st Qu.:155.0          1st Qu.:62.00         
##  Median :70.80          Median :158.0          Median :64.00         
##  Mean   :70.18          Mean   :158.5          Mean   :63.54         
##  3rd Qu.:71.40          3rd Qu.:162.0          3rd Qu.:65.00         
##  Max.   :72.50          Max.   :173.0          Max.   :70.00         
##  NA's   :5                                     NA's   :5             
##  ManufacturingProcess34 ManufacturingProcess35 ManufacturingProcess36
##  Min.   :2.300          Min.   :463.0          Min.   :0.01700       
##  1st Qu.:2.500          1st Qu.:490.0          1st Qu.:0.01900       
##  Median :2.500          Median :495.0          Median :0.02000       
##  Mean   :2.494          Mean   :495.6          Mean   :0.01957       
##  3rd Qu.:2.500          3rd Qu.:501.5          3rd Qu.:0.02000       
##  Max.   :2.600          Max.   :522.0          Max.   :0.02200       
##  NA's   :5              NA's   :5              NA's   :5             
##  ManufacturingProcess37 ManufacturingProcess38 ManufacturingProcess39
##  Min.   :0.000          Min.   :0.000          Min.   :0.000         
##  1st Qu.:0.700          1st Qu.:2.000          1st Qu.:7.100         
##  Median :1.000          Median :3.000          Median :7.200         
##  Mean   :1.014          Mean   :2.534          Mean   :6.851         
##  3rd Qu.:1.300          3rd Qu.:3.000          3rd Qu.:7.300         
##  Max.   :2.300          Max.   :3.000          Max.   :7.500         
##                                                                      
##  ManufacturingProcess40 ManufacturingProcess41 ManufacturingProcess42
##  Min.   :0.00000        Min.   :0.00000        Min.   : 0.00         
##  1st Qu.:0.00000        1st Qu.:0.00000        1st Qu.:11.40         
##  Median :0.00000        Median :0.00000        Median :11.60         
##  Mean   :0.01771        Mean   :0.02371        Mean   :11.21         
##  3rd Qu.:0.00000        3rd Qu.:0.00000        3rd Qu.:11.70         
##  Max.   :0.10000        Max.   :0.20000        Max.   :12.10         
##  NA's   :1              NA's   :1                                    
##  ManufacturingProcess43 ManufacturingProcess44 ManufacturingProcess45
##  Min.   : 0.0000        Min.   :0.000          Min.   :0.000         
##  1st Qu.: 0.6000        1st Qu.:1.800          1st Qu.:2.100         
##  Median : 0.8000        Median :1.900          Median :2.200         
##  Mean   : 0.9119        Mean   :1.805          Mean   :2.138         
##  3rd Qu.: 1.0250        3rd Qu.:1.900          3rd Qu.:2.300         
##  Max.   :11.0000        Max.   :2.100          Max.   :2.600         
## 

Imputing Missing Data

# Separate predictors and outcome
predictors <- ChemicalManufacturingProcess[, -1]
outcome <- ChemicalManufacturingProcess$Yield

# Imputation and scaling
preProc <- preProcess(predictors, method = c("knnImpute", "center", "scale"))

predictors_processed <- predict(preProc, newdata = predictors)

# Combine outcome with processed predictors
finalData <- data.frame(predictors_processed, Yield = outcome)

Train-Test Split

set.seed(123)
trainIndex <- createDataPartition(finalData$Yield, p = 0.8, list = FALSE)
trainData <- finalData[trainIndex, ]
testData <- finalData[-trainIndex, ]

Preprocess Dataset for Centering & Scaling

preProcValues <- preProcess(trainData, method = c("center", "scale"))
trainData <- predict(preProcValues, trainData)
testData <- predict(preProcValues, testData)
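
The step above estimates the centering and scaling statistics on the training rows and reuses them for the test rows. A minimal base R sketch of that idea follows, restricted to the predictor columns so that the outcome stays in its original units. The toy data frame and its column names x1 and x2 are assumptions for illustration, not the chemical manufacturing data.

```r
# Toy data (assumption): two predictors and an outcome named Yield
set.seed(7)
toy <- data.frame(x1 = rnorm(100, mean = 50, sd = 10),
                  x2 = runif(100, min = 0, max = 5))
toy$Yield <- 40 + 0.05 * toy$x1 - 0.3 * toy$x2 + rnorm(100, sd = 0.5)

# Split first, so the test rows never influence the preprocessing statistics
train_rows <- sample(nrow(toy), 80)
train <- toy[train_rows, ]
test  <- toy[-train_rows, ]

# Estimate means and standard deviations from the training predictors only
predictor_cols <- setdiff(names(toy), "Yield")
mu <- colMeans(train[, predictor_cols])
s  <- apply(train[, predictor_cols], 2, sd)

# Apply the same training-set statistics to both sets; Yield is left untouched
train[, predictor_cols] <- scale(train[, predictor_cols], center = mu, scale = s)
test[, predictor_cols]  <- scale(test[, predictor_cols],  center = mu, scale = s)

round(colMeans(train[, predictor_cols]), 10)  # training predictors are centered
range(train$Yield)                            # outcome remains in original units
```

Leaving the outcome unscaled keeps RMSE and MAE directly interpretable in yield units; the same ordering (split, then estimate, then apply) would also apply to an imputation step.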

Train Nonlinear Models

We now train three models: k-NN, MARS, and random forest.

#a k-NN model
set.seed(123)
knnModel <- train(
  Yield ~ ., data = trainData,
  method = "knn",
  trControl = trainControl(method = "cv", number = 10)
)


#b MARS Model
set.seed(123)
marsModel <- train(
  Yield ~ ., data = trainData,
  method = "earth",
  trControl = trainControl(method = "cv", number = 10)
)


#c RandomForest model
set.seed(123)
rfModel <- train(
  Yield ~ ., data = trainData,
  method = "rf",
  trControl = trainControl(method = "cv", number = 10)
)

Evaluate the Models on Test Set

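We now predict on the test data and evaluate each model using RMSE, R-squared, and MAE. Because Yield was centered and scaled together with the predictors in the preprocessing step, RMSE and MAE are reported in standardized units (training-set standard deviations of yield) rather than in original yield units.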

#a k-NN predictions and evaluation
knnPred <- predict(knnModel, newdata = testData)
postResample(pred = knnPred, obs = testData$Yield)
##      RMSE  Rsquared       MAE 
## 0.7560399 0.4284176 0.6249868
#b MARS predictions and evaluation
marsPred <- predict(marsModel, newdata = testData)
postResample(pred = marsPred, obs = testData$Yield)
##      RMSE  Rsquared       MAE 
## 0.7063439 0.5266952 0.5992609
#c Random Forest predictions and evaluation
rfPred <- predict(rfModel, newdata = testData)
postResample(pred = rfPred, obs = testData$Yield)
##      RMSE  Rsquared       MAE 
## 0.6879460 0.5309137 0.5155775

On the test set, random forest performs best, with the lowest RMSE (0.688) and MAE (0.516) and the highest R-squared (0.531). MARS is a close second (RMSE 0.706, R-squared 0.527), and k-NN is the weakest (RMSE 0.756, R-squared 0.428).

(b) Which predictors are most important in the optimal nonlinear regression model?

We extract variable importance from the optimal model, that is, the one with the best performance. The MARS ranking is shown as well for comparison.

# Variable importance for MARS
plot(varImp(marsModel), main = "Variable Importance for MARS")

Variable importance for Random Forest

plot(varImp(rfModel), main = "Variable Importance for Random Forest")

(c) Explore the relationships between the top predictors and the response

Let’s see how the top predictors relate to the response (Yield). We use scatterplots with a loess smoother.

# Get top 5 predictors from Random Forest
rf_importance <- varImp(rfModel)$importance
rf_sorted <- rf_importance[order(rf_importance$Overall, decreasing = TRUE), , drop = FALSE]
rf_top <- rownames(rf_sorted)[1:5]
# Plot predictor vs Yield
for (var in rf_top) {
  # .data[[var]] looks up the column named by the string in var (replaces deprecated aes_string)
  p <- ggplot(trainData, aes(x = .data[[var]], y = Yield)) +
    geom_point(alpha = 0.4, color = "darkgreen") +
    geom_smooth(method = "loess", se = TRUE, color = "darkorange") +
    labs(title = paste("Yield vs", var),
         x = var,
         y = "Yield") +
    theme_minimal()
  print(p)
}
## `geom_smooth()` using formula = 'y ~ x'

## `geom_smooth()` using formula = 'y ~ x'

## `geom_smooth()` using formula = 'y ~ x'

## `geom_smooth()` using formula = 'y ~ x'

## `geom_smooth()` using formula = 'y ~ x'

# Extract top 5 most important predictors
rf_imp <- varImp(rfModel)$importance
rf_imp_sorted <- rf_imp[order(rf_imp$Overall, decreasing = TRUE), , drop = FALSE]
top_rf_predictors <- rownames(rf_imp_sorted)[1:5]

# Scatter plots for those predictors vs Yield
featurePlot(x = trainData[, top_rf_predictors], 
            y = trainData$Yield, 
            plot = "scatter",
            auto.key = list(columns = 2))