Case study X5: Deep learning applied to a regression problem

Foreword: About the Machine Learning in Medicine (MLM) project

The MLM project was initiated in 2016 and aims to:

1. Encourage the use of Machine Learning techniques in medical research in Vietnam

2. Promote the use of the R statistical programming language, an open-source and leading tool for practicing data science.

Background

In the previous case study (X4), we saw the power of Deep learning applied to a binary classification problem. In this case study (X5) we introduce another application of Deep learning: developing a predictive model for a numerical response variable, i.e. a regression task. Most predictive models for physiological and clinical parameters are based on linear regression, but as you will see below, black-box algorithms can handle regression problems as well as GLM-based methods, and even better when the relationships are nonlinear.

The present study involves a pulmonary function parameter called DLCO, the lung diffusing capacity for carbon monoxide (CO). It is a clinical parameter for evaluating the gas-exchange function of the lungs. DLCO is measured by having the subject inspire a gas mixture containing helium and CO, then measuring the difference in partial pressure between inspired and expired CO after a 10-second breath-hold. DLCO is defined as the volume of CO that diffuses into the lung capillary blood per minute for each unit of pressure gradient (ml/min/mmHg or mmol/min/kPa). DLCO can help detect respiratory diseases such as COPD, lung fibrosis and pulmonary hypertension.

This case study uses a real dataset of DLCO measured in 531 healthy Caucasians. Our goal is to develop a predictive model that estimates DLCO from Gender (Male or Female), Height (cm) and Age (years). The prediction can also be interpreted as the mean DLCO value (mean predicted) of a virtual population of many people sharing the same gender, age and height.

Preparation

library(tidyverse)   #Loading tidyverse 

#Setting ggplot theme

my_theme <- function(base_size = 10, base_family = "sans"){
  theme_minimal(base_size = base_size, base_family = base_family) +
    theme(
      axis.text = element_text(size = 10),
      axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5),
      axis.title = element_text(size = 12),
      panel.grid.major = element_line(color = "grey"),
      panel.grid.minor = element_blank(),
      panel.background = element_rect(fill = "#fdf9ff"),
      strip.background = element_rect(fill = "#400156", color = "#400156", size =0.5),
      strip.text = element_text(face = "bold", size = 10, color = "white"),
      legend.position = "bottom",
      legend.justification = "center",
      legend.background = element_blank(),
      panel.border = element_rect(color = "grey30", fill = NA, size = 0.5)
    )
}
theme_set(my_theme())

Data exploration

data<-read.csv("DLCO.csv",sep=";")%>%as_tibble()

data$Sex=recode_factor(data$Sex,`0` = "Female", `1` = "Male")

Hmisc::describe(data)
## data 
## 
##  4  Variables      531  Observations
## ---------------------------------------------------------------------------
## Sex 
##        n  missing distinct 
##      531        0        2 
##                         
## Value      Female   Male
## Frequency     264    267
## Proportion  0.497  0.503
## ---------------------------------------------------------------------------
## Age 
##        n  missing distinct     Info     Mean      Gmd      .05      .10 
##      531        0       70        1    44.48    19.35     22.0     23.0 
##      .25      .50      .75      .90      .95 
##     30.0     43.0     56.0     69.0     75.5 
## 
## lowest : 16 17 18 19 20, highest: 81 82 84 85 87
## ---------------------------------------------------------------------------
## Height 
##        n  missing distinct     Info     Mean      Gmd      .05      .10 
##      531        0       76    0.999    170.6    11.02      155      158 
##      .25      .50      .75      .90      .95 
##      163      171      177      182      188 
## 
## lowest : 147.0 149.0 151.0 151.5 152.0, highest: 193.0 194.0 195.0 196.0 200.0
## ---------------------------------------------------------------------------
## DLCO 
##        n  missing distinct     Info     Mean      Gmd      .05      .10 
##      531        0      257        1    29.91    8.814    17.55    20.00 
##      .25      .50      .75      .90      .95 
##    24.40    29.60    35.80    39.90    42.90 
## 
## lowest : 10.0 11.0 11.3 12.2 12.5, highest: 48.4 49.4 49.9 50.4 57.8
## ---------------------------------------------------------------------------

Based on the descriptive statistics, our sample consists of adults aged 16 to 87 years. The 5th to 95th percentile range of Height was 155 to 188 cm. These parameters set the limits of prediction for our future model, i.e. the model will no longer be reliable for children or for subjects shorter than about 150 cm. The median DLCO was 29.6 ml/min/mmHg (5th to 95th percentile: 17.6 to 42.9). The two genders were equally represented (49.7% female, 50.3% male).

data%>%ggplot(aes(x=Height,y=DLCO,fill=Sex))+geom_point(shape=21,size=2,color="black")+facet_wrap(~Sex,scales="free")+geom_smooth(aes(color=Sex))+scale_fill_brewer(palette = "Set1")+scale_color_brewer(palette = "Set1")+ggtitle("DLCO ~ Height")

data%>%ggplot(aes(x=Age,y=DLCO,fill=Sex))+geom_point(shape=21,size=2,color="black")+facet_wrap(~Sex,scales="free")+geom_smooth(aes(color=Sex))+scale_fill_brewer(palette = "Set1")+scale_color_brewer(palette = "Set1")+ggtitle("DLCO ~ Age")

data%>%ggplot(aes(x=DLCO,fill=Sex,color=Sex))+geom_histogram(alpha=0.5,bins = 100)+scale_fill_brewer(palette = "Set1")+scale_color_brewer(palette = "Set1")+ggtitle("DLCO distribution")+facet_wrap(~Sex,ncol=1)

Data visualisation shows that the target variable is approximately normally distributed within the two gender groups. Females have lower DLCO than males. The relationship between DLCO and Height is roughly linear and proportional, but the association between DLCO and Age is more complicated. It can be foreseen that the nonlinear variation of DLCO as a function of Age cannot be captured by a simple linear model.

Like other lung function parameters, DLCO follows a nonlinear growth curve as a function of age and is usually not normally distributed in the general population. When the target parameter is not normally distributed, a logarithmic transformation (or a log link function) is usually applied. Some authors also attempt to recapture the nonlinear growth curve by splitting their sample into separate subsets for children and older people, or by using polynomial equations. Despite these efforts, traditional models often fail to recapture the real relationship between lung function indices and age.
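
To make the contrast concrete, here is a minimal sketch of that traditional GLM-based approach (illustrative only; it assumes the data tibble loaded in the next section, and the object name glm_ref is ours):

# Illustrative sketch of a conventional reference equation for DLCO:
# a Gaussian GLM with a log link and a quadratic term in Age as a
# crude stand-in for the nonlinear age effect discussed above.
glm_ref <- glm(DLCO ~ Sex + Height + poly(Age, 2),
               family = gaussian(link = "log"),
               data = data)
summary(glm_ref)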

Initialisation of h2o

library(h2o)

h2o.init(nthreads = -1,max_mem_size ="4g")
##  Connection successful!
## 
## R is connected to the H2O cluster: 
##     H2O cluster uptime:         4 hours 58 minutes 
##     H2O cluster version:        3.10.3.6 
##     H2O cluster version age:    1 month and 24 days  
##     H2O cluster name:           H2O_started_from_R_Admin_guq467 
##     H2O cluster total nodes:    1 
##     H2O cluster total memory:   2.80 GB 
##     H2O cluster total cores:    4 
##     H2O cluster allowed cores:  4 
##     H2O cluster healthy:        TRUE 
##     H2O Connection ip:          localhost 
##     H2O Connection port:        54321 
##     H2O Connection proxy:       NA 
##     R Version:                  R version 3.3.1 (2016-06-21)

Converting the data frame to an h2o frame and splitting the data

wdata=as.h2o(data)
splits=h2o.splitFrame(wdata, ratios=c(0.75,0.125),seed=123)

wtrain=splits[[1]]
wvalid=splits[[2]]
wtest=splits[[3]]

wtrain
##    Sex Age Height DLCO
## 1 Male  18  177.5 35.4
## 2 Male  18  181.0 42.5
## 3 Male  18  176.0 37.0
## 4 Male  22  180.0 38.0
## 5 Male  23  185.0 37.1
## 6 Male  23  163.5 30.4
## 
## [391 rows x 4 columns]
wvalid
##    Sex Age Height DLCO
## 1 Male  22  171.0 38.0
## 2 Male  26  163.0 29.3
## 3 Male  39  163.5 37.4
## 4 Male  39  179.0 34.7
## 5 Male  61  174.0 22.1
## 6 Male  18  176.0 40.9
## 
## [72 rows x 4 columns]
wtest
##    Sex Age Height DLCO
## 1 Male  25  171.5 37.1
## 2 Male  25  174.0 36.7
## 3 Male  30  178.0 32.7
## 4 Male  47  167.0 25.4
## 5 Male  51  170.0 30.1
## 6 Male  85  166.0 18.2
## 
## [68 rows x 4 columns]
train=as.data.frame(wtrain)
test=as.data.frame(wtest)
valid=as.data.frame(wvalid)

p1=train%>%ggplot(aes(x=DLCO,fill=Sex,color=Sex))+geom_density(alpha=0.5,show.legend = F)+scale_fill_brewer(palette = "Set1")+scale_color_brewer(palette = "Set1")+ggtitle("Train")
p2=test%>%ggplot(aes(x=DLCO,fill=Sex,color=Sex))+geom_density(alpha=0.5,show.legend = F)+scale_fill_brewer(palette = "Set1")+scale_color_brewer(palette = "Set1")+ggtitle("Test")
p3=valid%>%ggplot(aes(x=DLCO,fill=Sex,color=Sex))+geom_density(alpha=0.5,show.legend = F)+scale_fill_brewer(palette = "Set1")+scale_color_brewer(palette = "Set1")+ggtitle("Validation")
p4=data%>%ggplot(aes(x=DLCO,fill=Sex,color=Sex))+geom_density(alpha=0.5,show.legend = F)+scale_fill_brewer(palette = "Set1")+scale_color_brewer(palette = "Set1")+ggtitle("Origin_data")

library(gridExtra)
grid.arrange(p4,p1,p2,p3,ncol=2)

The original dataset was split into 3 parts:

1. A Training subset containing 391 cases
2. A Validation subset (for calibrating the model) containing 72 cases
3. A Test subset for independent validation (testing) containing 68 cases

The response variable was similarly and approximately normally distributed across the 3 subsets and within both genders.
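
As a quick sanity check, we can count the cases in each subset (using the R data frames created above); the counts should match the figures just reported:

c(Train = nrow(train), Validation = nrow(valid), Test = nrow(test))   # subset sizes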

Training a Regression Deep neural network in h2o

The response variable is DLCO; the features are Age, Height and Sex.

The structure and mechanism of deep neural networks were described in the previous tutorial. In this study we face a regression task, so our neural network has only one neuron in its output layer. Our features include one categorical variable with 2 levels (Sex = Male or Female) and two numerical variables (Age and Height), so the input layer will have 4 neurons (2 for Sex, 1 for each numerical variable).

As the relationship between DLCO and the subject's growth level (Height and Age) is nonlinear, we use at least two hidden layers in our deep network. The more neurons a layer contains, the more capacity the model has to capture the structure of the data.

Our training configuration consists of:

A feedforward neural network with 2 hidden layers of 300 neurons each; the activation function is the hyperbolic tangent (with dropout).

The stopping metric is RMSE, the square root of MSE (training stops early if RMSE fails to improve beyond the stopping tolerance for 3 consecutive scoring rounds).

20-fold cross-validation, 100 epochs.

Independent model scoring on the validation frame.

response="DLCO"
features=setdiff(colnames(wtrain),response)

dlmod=h2o.deeplearning  (x = features,
                         y = response,
                         model_id = "DL_Reg",
                         training_frame = wtrain, validation_frame =wvalid,                            
                         nfolds = 20,
                         hidden = c(300,300),
                         stopping_metric = "RMSE",
                         replicate_training_data = TRUE,
                         stopping_tolerance = 0.001,
                         stopping_rounds =3,
                         overwrite_with_best_model=TRUE,
                         epochs=100,
                         activation = "TanhWithDropout",
                         keep_cross_validation_fold_assignment = TRUE,
                         keep_cross_validation_predictions=FALSE,
                         score_each_iteration = TRUE,
                         export_weights_and_biases=TRUE,
                         reproducible = TRUE,seed=1234)
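
Since export_weights_and_biases was set to TRUE, we can optionally inspect the learned weight matrices to confirm the layer sizes described earlier (the exact layout of the returned frame may differ between h2o versions):

w1=h2o.weights(dlmod, matrix_id = 1)   # weight matrix between the input layer and the first hidden layer
dim(w1)                                # dimensions reflect the input neurons (Sex dummies, Age, Height) and the 300 first-layer neurons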

Model evaluation

dlmod@model$cross_validation_metrics_summary%>%.[,c(1,2)]
##                         mean         sd
## mae                 3.452794 0.35994998
## mse                18.549406  4.2921925
## r2                 0.6814604 0.06280451
## residual_deviance  18.549406  4.2921925
## rmse               4.2519546 0.48491725
## rmsle             0.14345716 0.01502005

The summarised, fold-averaged model scoring shows a good performance.

For a regression supervised machine learning task, h2o provides 6 metrics for measuring the model's performance:

MSE: Mean Squared Error.
Deviance: residual deviance. If the outcome distribution is Gaussian, the deviance is exactly the same as the MSE.
RMSE: the square root of MSE; it expresses the error in the same units as the response variable.
MAE: Mean Absolute Error. As with RMSE, the units of MAE are the same as those of the response variable.
R-squared (R²), also known as the coefficient of determination. This is the traditional metric for evaluating linear regression models, which makes it convenient for reporting model performance in scientific papers.
RMSLE: Root Mean Squared Logarithmic Error. This can be more useful than RMSE when an under-prediction is considered worse than an over-prediction.
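
To see exactly how these numbers arise, the sketch below recomputes the main metrics by hand from predictions on the validation frame (illustrative only; it uses the usual textbook definitions, which may differ slightly from h2o's internal ones, e.g. for RMSLE):

val_pred=as.data.frame(h2o.predict(dlmod, wvalid))$predict    # predicted DLCO
val_obs=as.data.frame(wvalid)$DLCO                            # observed DLCO
mse=mean((val_pred-val_obs)^2)                                # Mean Squared Error
rmse=sqrt(mse)                                                # Root MSE, same units as DLCO
mae=mean(abs(val_pred-val_obs))                               # Mean Absolute Error
r2=1-sum((val_obs-val_pred)^2)/sum((val_obs-mean(val_obs))^2) # R-squared
rmsle=sqrt(mean((log1p(val_pred)-log1p(val_obs))^2))          # Root Mean Squared Log Error
c(MSE=mse, RMSE=rmse, MAE=mae, R2=r2, RMSLE=rmsle)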

The scoring history on the training and validation frames for deviance, RMSE and MAE can be visualised as follows:

plot(dlmod,
     timestep = "epochs",
     metric = "rmse")

plot(dlmod,
     timestep = "epochs",
     metric = "mae")

plot(dlmod,
     timestep = "epochs",
     metric = "deviance")

If you are not satisfied with those graphs, here are more advanced plots built with ggplot2:

cvf=dlmod@model$cross_validation_metrics_summary%>%as.data.frame()%>%mutate(Metric=rownames(.))%>%gather(cv_1_valid:cv_20_valid,key="Fold",value="Result")

tshf=dlmod@model$scoring_history%>%as_tibble()%>%gather(training_rmse:training_mae,key="Metric",value="Training_Score")

vshf=dlmod@model$scoring_history%>%as_tibble()%>%gather(validation_rmse:validation_mae,key="Metric",value="Validation_Score")



cvf%>%ggplot(aes(x=Fold,y=Result,color=Metric,fill=Metric))+geom_line(group=1,size=1,show.legend = F)+geom_point(shape=21,size=3,color="black",show.legend = F)+theme(axis.text.x=element_blank())+facet_wrap(~Metric,scales="free",ncol=2)+ggtitle("20-fold Cross-validation")

tshf%>%ggplot(aes(x=epochs,y=Training_Score,color=Metric,fill=Metric))+geom_line(group=1,size=1,show.legend = F)+geom_point(shape=21,size=3,color="black",show.legend = F)+facet_wrap(~Metric,scales="free",ncol=1)+ggtitle("Training score history")

vshf%>%ggplot(aes(x=epochs,y=Validation_Score,color=Metric,fill=Metric))+geom_line(group=1,size=1,show.legend = F)+geom_point(shape=21,size=3,color="black",show.legend = F)+facet_wrap(~Metric,scales="free",ncol=1)+ggtitle("Validation score history")

Marginal effect

Though deep neural networks are black-box models, we can visualise the underlying mechanism of our regression model via marginal effect (partial dependence) plots:

pdp=h2o.partialPlot(dlmod,data=wdata,nbins =100,plot=F)   # one partial dependence table per feature, in the order of the training columns
pdfa=pdp[[2]]%>%as_tibble()   # Age
pdfh=pdp[[3]]%>%as_tibble()   # Height
pp1=pdfh%>%ggplot(aes(x=Height,y=mean_response))+geom_point(color="blue3",size=1)+scale_x_continuous(breaks=c(150,160,170,180,190,200))+scale_y_continuous(breaks=c(26,27,28,29,30,31,32,33,34,35,36))
pp2=pdfa%>%ggplot(aes(x=Age,y=mean_response))+geom_point(color="red3",size=1)+scale_x_continuous(breaks=c(20,30,40,50,60,70,80,90))+scale_y_continuous(breaks=c(23,24,25,26,27,28,29,30,31,32,33,34,35,36))

grid.arrange(pp1,pp2)

External validation on an independent subset

h2o.performance(dlmod,wtest)
## H2ORegressionMetrics: deeplearning
## 
## MSE:  20.40886
## RMSE:  4.517617
## MAE:  3.772528
## RMSLE:  0.153166
## Mean Residual Deviance :  20.40886
h2o.performance(dlmod,wdata)
## H2ORegressionMetrics: deeplearning
## 
## MSE:  20.07843
## RMSE:  4.480896
## MAE:  3.597462
## RMSLE:  0.1510107
## Mean Residual Deviance :  20.07843

When applied either to the independent test set or to the pooled sample, our deep learning regression model consistently shows good performance.

Qualitative validation

Now we will evaluate the accuracy of our predictive model in a qualitative way.

First, we want to know whether our model overestimates or underestimates DLCO in healthy subjects. Then we would like to know whether, when using our model, healthy subjects could be misclassified as "abnormal".

Measured lung function values are usually interpreted as a percentage of this mean predicted value. Patients are classified as abnormal if their value drops below 80% or 70% of the predicted value. For clinical practice, we are also interested in the lower and upper bounds of the normal range. These special percentiles are based on the standardized score (Z-score) that should be used for interpretation.

The Z-score is defined as the deviation from the mean, expressed in standard deviation units (SD). In pulmonary function testing, the Z-score can be determined by a simple rule using the residual standard deviation (RSD). This method is straightforward but only correct when the parameter is normally distributed within the sample on which the regression model was built. The predictive model gives us the centered mean and an error term (residual SD) of an approximately Gaussian distribution. The Z-score can then be determined as:

Z-Score = (Observed – Mean predicted)/Residual standard deviation
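
For example, a subject with a measured DLCO of 25, a predicted value of 30 and a residual standard deviation of 4.3 would have a Z-score of (25 - 30)/4.3 ≈ -1.16, which lies within the normal range (-1.645 to +1.645).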

pred=h2o.predict(dlmod,newdata=wdata)%>%as_tibble()%>%mutate(truth=data$DLCO,
                                                             Sex=data$Sex,
                                                             Height=data$Height,
                                                             Age=data$Age,
                                                             Error=(.$predict-data$DLCO),   # prediction error (predicted - observed)
                                                             PC=100*(data$DLCO/.$predict),  # observed value as a percentage of predicted
                                                             Tendency=ifelse(.$predict>data$DLCO,"Overestimated", "Underestimated"))
rsd=sd(pred$Error)   # residual standard deviation of the model

pred$LLN=pred$predict-1.645*rsd   # lower limit of normal (5th percentile)

pred$ULN=pred$predict+1.645*rsd   # upper limit of normal (95th percentile)


pred$Z_score=(pred$truth-pred$predict)/rsd   # Z-score as defined above

pred=pred%>%mutate(.,ToleranceZscore=ifelse(Z_score<(-1.645)|Z_score>1.645,"Unacceptable", "Acceptable"))
pred%>%ggplot(aes(x=Age))+geom_point(aes(y=truth),color="grey80")+geom_point(aes(y=predict,color=Tendency))+facet_wrap(~Sex)+scale_color_brewer(palette = "Set1")

pred%>%ggplot(aes(x=Height))+geom_point(aes(y=truth),color="grey80")+geom_point(aes(y=predict,color=Tendency))+facet_wrap(~Sex)+scale_color_brewer(palette = "Set1")

pred%>%ggplot()+geom_density(aes(x=truth),alpha=0.5,fill="skyblue")+geom_density(aes(x=predict),alpha=0.5,fill="red")+facet_wrap(~Sex)

pred%>%ggplot()+geom_density(aes(x=Error,fill=Sex),alpha=0.5)+scale_fill_brewer(palette = "Set1")

Our model has a slight tendency to overestimate DLCO in both females and males. However, the predictions are in good overall agreement with the observed values. It seems that our model has also successfully captured the nonlinear relationship between DLCO and Age.

pred%>%ggplot()+geom_point(aes(x=Z_score,y=PC,color=Tendency))+geom_hline(yintercept=c(80,100),linetype=2)+geom_vline(xintercept=c(-1.645,1.645),linetype=2)

pred%>%ggplot(aes(x=Age))+geom_point(aes(y=truth),color="grey80")+geom_point(aes(y=predict,color=ToleranceZscore))+geom_smooth(aes(y=predict),se=F,color="black")+geom_smooth(aes(y=ULN),se=F,color="blue4",linetype=2)+geom_smooth(aes(y=LLN),se=F,color="red4",linetype=2)+facet_wrap(~Sex,scales="free")+scale_color_brewer(palette = "Set1",direction = -1)

pred%>%ggplot(aes(x=Height))+geom_point(aes(y=truth),color="grey80")+geom_point(aes(y=predict,color=ToleranceZscore))+geom_smooth(aes(y=predict),se=F,color="black")+geom_smooth(aes(y=ULN),se=F,color="blue4",linetype=2)+geom_smooth(aes(y=LLN),se=F,color="red4",linetype=2)+facet_wrap(~Sex,scales="free")+scale_color_brewer(palette = "Set1",direction = -1)

These graphs show the level of agreement between the predicted and observed DLCO values in our sample, as well as the mean predicted value and the upper and lower limits of normal.

We presume that it would be unacceptable if a healthy individual were misclassified as "abnormal" based on the Z-score (i.e. if their Z-score falls below -1.645 or above 1.645).

As we can see, the proportion of unacceptable classifications is relatively small; most healthy subjects were correctly classified as having a normal DLCO.
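
This claim can be quantified directly from the pred table built above (illustrative):

pred%>%
  count(Sex,ToleranceZscore)%>%
  group_by(Sex)%>%
  mutate(Percent=round(100*n/sum(n),1))   # share of Acceptable vs Unacceptable classifications within each sex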

Conclusion

As we have seen, Deep learning is also very powerful for regression tasks. There are 4 good reasons why Deep learning could be adopted for predicting lung function parameters:

  1. We don't need an explicit, interpretable regression equation; our goal is prediction accuracy, and sophisticated algorithms give us confidence in that.

  2. In daily practice, the normal value range is predicted by computer, not by manual calculation. A neural network model can easily be implemented in the software.

  3. Most existing reference value sets fail to capture the nonlinear relationship between lung function and age, due to the limitations of the generalized linear regression method. A neural network model handles nonlinearity easily.

  4. Manual development of a predictive model can take a very long time; researchers must wait for years to collect enough data from healthy people. A machine learning based workflow can significantly accelerate this process. Every day, hundreds to thousands of lung function measurements are performed in any country, and many of the tested subjects are normal. If those data could be automatically translated by computers into predictive models, reference values could be updated daily, or even in real time, for our patients.

See you soon in the next tutorial, and thank you for joining us.

END