Foreword: About the Machine Learning in Medicine (MLM) project
The MLM project was initiated in 2016 and aims to:
1. Encourage the use of Machine Learning techniques in medical research in Vietnam
2. Promote the use of the R statistical programming language, an open-source and leading tool for practicing data science.
Background
In this 5th case study, we introduce for the first time the combination of Machine Learning (ML) and proportional hazards regression algorithms to evaluate the impact of risk factors on the survival time of patients with cardiovascular disease. Our goal is to extend the Cox proportional hazards (Cox-PH) model with ML techniques, including:
Feature selection using Random-forest methods and a Cox-PH model with permutation and univariate scores
A benchmark study based on cross-validation, comparing the performance of 6 different survival analysis algorithms
Bootstrapping a Cox-PH model and extracting the confidence intervals of the hazard ratios
Exploring the boosted Cox-PH model
Plotting the marginal effects of predictors on the risk of mortality
Materials and method
The original dataset was provided in Chapter 4.3, Machine Learning in Medicine-Cookbook 1 (T. J. Cleophas and A. H. Zwinderman, SpringerBriefs in Statistics 2014). You can get the original data in SPSS SAV format (Coxoutcomeprediction.sav) from their website: extras.springer.com.
The study included 60 patients aged 45 to 91 years. The patients were followed up for up to 30 months, until death or loss to follow-up (censoring). Two potential risk factors were recorded: age and treatment modality (with or without treatment).
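As a minimal sketch of how such follow-up data are usually encoded in R, the Surv() function from the survival package pairs each follow-up time with an event indicator. The five patients below are invented for illustration and are not taken from the dataset.

```r
library(survival)

# Hypothetical follow-up times (months) and event indicators (TRUE = death observed)
follow_up <- c(1, 2, 3, 16, 30)
died      <- c(TRUE, TRUE, FALSE, TRUE, FALSE)

# Surv() builds a right-censored outcome; censored times print with a trailing "+"
outcome <- Surv(follow_up, died)
print(outcome)
```

A patient who is still alive at the last contact contributes information only up to that time, which is why the event indicator has to travel together with the follow-up time.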
Data analysis was performed in the R statistical programming language, using the mlr package (https://mlr-org.github.io/mlr-tutorial/release/html/index.html). At the time of writing, mlr was, to our knowledge, one of the few general ML frameworks in R that supported survival analysis.
First, the following packages must be loaded in RStudio (some other packages might also be required during the analysis), and we will customize the appearance of the ggplot2 graphics:
library(foreign)
library(tidyverse)
library(survival)
library(mlr)
library(gridExtra)
library(ggfortify)
library(CoxBoost)
my_theme <- function(base_size =9, base_family = "sans"){
theme_minimal(base_size = base_size, base_family = base_family) +
theme(
axis.text = element_text(size=8),
axis.text.x = element_text(angle = 0, vjust = 0.5, hjust = 0.5),
axis.title = element_text(size =8),
panel.grid.major = element_line(color = "gray"),
panel.grid.minor = element_blank(),
panel.background = element_rect(fill = "#ffffff"),
strip.background = element_rect(fill = "black", color = "black", size =0.5),
strip.text = element_text(face = "bold", size = 8, color = "white"),
legend.position = "bottom",
legend.justification = "center",
legend.background = element_blank(),
panel.border = element_rect(color = "grey5", fill = NA, size = 0.5)
)
}
theme_set(my_theme())
mycolors=c("darkred","black","#004431","#000c44","#2d1600","purple","#202302")
myfillcolors=c("#d10c0c","grey10","#1e6651","#213989","#5b431a","purple","#393f03")
Results
Step 0: Data loading and reshape
require(foreign)
df=read.spss("Coxoutcomeprediction.sav",to.data.frame = T,use.value.labels = TRUE)%>%as_tibble()%>%mutate(.,Event=factor(.$event,levels=c(1,0),labels=c("Occured","Censored")),event=(.$event==1))
df$treatment=factor(df$treatment,levels=c(1,0),labels=c("Yes","No"))
df
## # A tibble: 60 × 5
## followupmonth event treatment age Event
## <dbl> <lgl> <fctr> <dbl> <fctr>
## 1 1 TRUE No 65 Occured
## 2 1 TRUE No 66 Occured
## 3 2 TRUE No 73 Occured
## 4 2 TRUE No 91 Occured
## 5 2 TRUE No 86 Occured
## 6 2 TRUE No 87 Occured
## 7 2 TRUE No 54 Occured
## 8 2 TRUE No 66 Occured
## 9 2 TRUE No 64 Occured
## 10 3 FALSE No 62 Censored
## # ... with 50 more rows
Once the above commands have been executed, we get a clean data frame (n=60) that will be used throughout our analysis.
Step 1: Data exploration
Figure 1: Descriptive analysis
df%>%ggplot(aes(x=age,y=followupmonth,color=treatment,fill=treatment))+geom_jitter(shape=21,size=5,alpha=0.6)+facet_wrap(~Event)+scale_color_manual(values=mycolors)+scale_fill_manual(values=myfillcolors)
df%>%ggplot(aes(x=treatment,y=followupmonth,color=treatment,fill=treatment))+stat_summary(fun.ymin=min,fun.ymax=max,fun.y="median",shape=21,size=1)+facet_wrap(~Event)+scale_color_manual(values=mycolors)+scale_fill_manual(values=myfillcolors)+coord_flip()
These figures indicate that, without treatment, the patients' survival time was markedly shorter. The patient's age might also contribute to the mortality risk, though its impact is not yet clear.
psych::describeBy(df$followupmonth,group=df$treatment)
## $Yes
## vars n mean sd median trimmed mad min max range skew kurtosis
## X1 1 30 25.37 5.04 28 25.92 2.97 16 30 14 -0.59 -1.28
## se
## X1 0.92
##
## $No
## vars n mean sd median trimmed mad min max range skew kurtosis se
## X1 1 30 12.1 11.14 8 11.21 8.9 1 30 29 0.68 -1.18 2.03
##
## attr(,"call")
## by.default(data = x, INDICES = group, FUN = describe, type = type)
The descriptive analysis showed that the median follow-up time differed between the two treatment modalities (28 vs 8 months). Note that this summary ignores censoring, so it is only a rough indication of survival time. Still, we can expect treatment to be an important predictor in our model.
Figure 2: Kaplan-Meier curve for Death events in patients with and without treatment
library(survival)
fit=survfit(Surv(followupmonth,event)~treatment,data=df)
summary(fit)
## Call: survfit(formula = Surv(followupmonth, event) ~ treatment, data = df)
##
## treatment=Yes
## time n.risk n.event survival std.err lower 95% CI upper 95% CI
## 16 30 2 0.933 0.0455 0.848 1.000
## 17 28 1 0.900 0.0548 0.799 1.000
## 18 27 1 0.867 0.0621 0.753 0.997
## 19 26 2 0.800 0.0730 0.669 0.957
## 20 24 1 0.767 0.0772 0.629 0.934
## 21 23 1 0.733 0.0807 0.591 0.910
## 22 22 1 0.700 0.0837 0.554 0.885
## 23 20 1 0.665 0.0865 0.515 0.858
## 24 19 1 0.630 0.0887 0.478 0.830
## 26 18 1 0.595 0.0905 0.442 0.802
## 27 17 1 0.560 0.0917 0.406 0.772
## 28 16 2 0.490 0.0926 0.338 0.710
## 29 13 3 0.377 0.0914 0.234 0.606
##
## treatment=No
## time n.risk n.event survival std.err lower 95% CI upper 95% CI
## 1 30 2 0.933 0.0455 0.848 1.000
## 2 28 7 0.700 0.0837 0.554 0.885
## 4 20 1 0.665 0.0865 0.515 0.858
## 5 19 1 0.630 0.0887 0.478 0.830
## 6 18 2 0.560 0.0917 0.406 0.772
## 7 16 1 0.525 0.0924 0.372 0.741
## 9 15 2 0.455 0.0924 0.306 0.677
## 11 13 1 0.420 0.0917 0.274 0.644
## 12 12 1 0.385 0.0905 0.243 0.610
## 14 11 1 0.350 0.0887 0.213 0.575
## 16 10 1 0.315 0.0865 0.184 0.540
## 17 9 1 0.280 0.0837 0.156 0.503
## 18 8 1 0.245 0.0802 0.129 0.465
library(ggfortify)
autoplot(fit)+scale_color_manual(values=mycolors)+scale_fill_manual(values=myfillcolors)
The figure indicates that the patients who did not receive treatment had a higher mortality risk than the treated patients.
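The survival column of the summary(fit) output can be checked by hand with the Kaplan-Meier product-limit formula. The sketch below copies the n.risk and n.event columns of the treatment=Yes block and rebuilds the curve with base R only; the variable names are chosen here for illustration.

```r
# n.risk and n.event columns copied from the treatment=Yes block of summary(fit)
time    <- c(16, 17, 18, 19, 20, 21, 22, 23, 24, 26, 27, 28, 29)
n_risk  <- c(30, 28, 27, 26, 24, 23, 22, 20, 19, 18, 17, 16, 13)
n_event <- c( 2,  1,  1,  2,  1,  1,  1,  1,  1,  1,  1,  2,  3)

# Product-limit estimate: S(t) = product over event times <= t of (1 - d_i / n_i)
km <- cumprod(1 - n_event / n_risk)
print(round(km, 3))

# Kaplan-Meier median: first event time at which S(t) drops to 0.5 or below
km_median <- time[which(km <= 0.5)[1]]
print(km_median)
```

The rounded values reproduce the survival column above (0.933, 0.900, ..., 0.377), and the curve first falls below 0.5 at month 28, where the estimate is 0.490.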
Step 2: Features exploration
Figure 3: Contribution of each predictor to the predicted mortality risk
library(mlr)
surv.task=makeSurvTask(id="Survival",data=df[,-5],target=c("followupmonth","event"),censoring="rcens")
rfi=generateFilterValuesData(surv.task,method=c("rf.importance"))%>%.$data%>%ggplot(aes(x=reorder(name,rf.importance),y=rf.importance,fill=reorder(name,rf.importance)))+geom_bar(alpha=0.8,stat="identity",color="black",show.legend=F)+scale_fill_manual(values=myfillcolors)+scale_x_discrete("Features")+coord_flip()
rmd=generateFilterValuesData(surv.task,method=c("rf.min.depth"))%>%.$data%>%ggplot(aes(x=reorder(name,rf.min.depth),y=rf.min.depth,fill=reorder(name,rf.min.depth)))+geom_bar(alpha=0.8,stat="identity",color="black",show.legend=F)+scale_fill_manual(values=myfillcolors)+scale_x_discrete("Features")+coord_flip()
uv=generateFilterValuesData(surv.task,imp.learner=makeLearner("surv.coxph"),method="univariate.model.score")%>%.$data%>%ggplot(aes(x=reorder(name,univariate.model.score),y=univariate.model.score,fill=reorder(name,univariate.model.score)))+geom_bar(alpha=0.8,stat="identity",color="black",show.legend=F)+scale_fill_manual(values=myfillcolors)+scale_x_discrete("Features")+coord_flip()
pi=generateFilterValuesData(surv.task,imp.learner=makeLearner("surv.coxph"),method="permutation.importance")%>%.$data%>%ggplot(aes(x=reorder(name,permutation.importance),y=permutation.importance,fill=reorder(name,permutation.importance)))+geom_bar(alpha=0.8,stat="identity",color="black",show.legend=F)+scale_fill_manual(values=myfillcolors)+scale_x_discrete("Features")+coord_flip()
grid.arrange(rfi,rmd,uv,pi)
The results confirm that treatment modality is the most important risk factor in our model. However, the patient's age also contributes to the mortality risk (based on the univariate model and permutation importance scores). Thus, both features will be included in our model.
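To illustrate the idea behind a permutation importance score, here is a minimal sketch that uses only the survival package. It is not the mlr implementation: the data are simulated, the helper names simulate_cohort and cindex are hypothetical, and the importance is defined here simply as the average drop in C-index after shuffling one feature.

```r
library(survival)
set.seed(1)

# Simulated cohort (an assumption for illustration, not the study data)
simulate_cohort <- function(n) {
  age <- round(runif(n, 45, 91))
  treatment <- factor(sample(c("Yes", "No"), n, replace = TRUE), levels = c("Yes", "No"))
  noise <- rnorm(n)  # a feature with no effect on survival
  rate <- 0.02 * exp(0.03 * (age - 65) + 0.8 * (treatment == "No"))
  t_event <- rexp(n, rate)
  t_cens <- runif(n, 5, 30)
  data.frame(age, treatment, noise,
             time = pmin(t_event, t_cens), status = t_event <= t_cens)
}
sim <- simulate_cohort(200)

fit <- coxph(Surv(time, status) ~ age + treatment + noise, data = sim)

# C-index of the fitted model on a data set; reverse = TRUE because a higher
# linear predictor means a shorter expected survival time
cindex <- function(model, data) {
  data$lp <- predict(model, newdata = data, type = "lp")
  concordance(Surv(time, status) ~ lp, data = data, reverse = TRUE)$concordance
}

base_c <- cindex(fit, sim)

# Importance = mean drop in C-index when one feature is randomly shuffled
perm_importance <- sapply(c("age", "treatment", "noise"), function(v) {
  mean(replicate(20, {
    shuffled <- sim
    shuffled[[v]] <- sample(shuffled[[v]])
    base_c - cindex(fit, shuffled)
  }))
})
print(round(perm_importance, 3))
```

A feature whose shuffling barely changes the C-index carries little predictive information, which is the reasoning behind ranking features by this score.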
Step 3: Benchmark study - Cox-PH model against 5 other survival analysis algorithms
# Benchmark study: comparing 6 different algorithms for the survival task
# Creating a list of 6 learners
svlearners=list(
makeLearner("surv.coxph"),
makeLearner("surv.cv.CoxBoost"),
makeLearner("surv.penalized.lasso"),
makeLearner("surv.penalized.ridge"),
makeLearner("surv.glmnet"),
makeLearner("surv.randomForestSRC")
)
rdesc=makeResampleDesc("RepCV",folds=4L,reps=15L)
# Initialising the benchmark study
set.seed(123)
svbnmrk=benchmark(svlearners,surv.task,rdesc)
bmrkdata=getBMRPerformances(svbnmrk, as.df = TRUE)%>%.[,-1]
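For readers who want to see what a repeated cross-validation of the C-index involves, the following sketch performs it by hand for a single Cox-PH model, using the survival package only. It is an illustration under stated assumptions: the data are simulated, the function names simulate_cohort and cv_cindex are hypothetical, and only 3 repetitions are run to keep it fast.

```r
library(survival)
set.seed(123)

# Simulated cohort (an assumption for illustration, not the study data)
simulate_cohort <- function(n) {
  age <- round(runif(n, 45, 91))
  treatment <- factor(sample(c("Yes", "No"), n, replace = TRUE), levels = c("Yes", "No"))
  rate <- 0.02 * exp(0.03 * (age - 65) + 0.8 * (treatment == "No"))
  t_event <- rexp(n, rate)
  t_cens <- runif(n, 5, 30)
  data.frame(age, treatment, time = pmin(t_event, t_cens), status = t_event <= t_cens)
}
sim <- simulate_cohort(120)

# Repeated k-fold cross-validation: one C-index per held-out fold
cv_cindex <- function(data, folds = 4, reps = 3) {
  scores <- numeric(0)
  for (r in seq_len(reps)) {
    # Random fold assignment, redrawn at every repetition
    fold_id <- sample(rep(seq_len(folds), length.out = nrow(data)))
    for (k in seq_len(folds)) {
      train_set <- data[fold_id != k, ]
      test_set <- data[fold_id == k, ]
      fit <- coxph(Surv(time, status) ~ age + treatment, data = train_set)
      test_set$lp <- predict(fit, newdata = test_set, type = "lp")
      c_k <- concordance(Surv(time, status) ~ lp, data = test_set, reverse = TRUE)$concordance
      scores <- c(scores, c_k)
    }
  }
  scores
}

scores <- cv_cindex(sim)
print(summary(scores))
```

With folds=4 and reps=15, as in the resampling description above, the same procedure would yield 60 C-index values per learner.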
Figure 4A,B,C,D: Distribution of C-index of 6 survival algorithms, based on a 4x15 cross-validation
plotBMRRanksAsBarChart(svbnmrk)+scale_fill_manual(values=myfillcolors)
plotBMRBoxplots(svbnmrk)+aes(fill=learner.id)+coord_flip()+scale_fill_manual(values=myfillcolors,name="Learners")
ggplot(bmrkdata)+geom_path(aes(x=iter,y=cindex,color=learner.id),size=1,alpha=0.8)+facet_wrap(~learner.id,scales="free",ncol=2)+scale_color_manual(values=myfillcolors)
ggplot(bmrkdata)+geom_histogram(color="black",aes(x=cindex,fill=learner.id),alpha=0.7)+facet_wrap(~learner.id,scales="free",ncol=2)+scale_fill_manual(values=myfillcolors)
Then we test for an overall difference with the Kruskal-Wallis test and perform pairwise comparisons among the algorithms, based on the Nemenyi test:
kruskal.test(data=bmrkdata,cindex~learner.id)
##
## Kruskal-Wallis rank sum test
##
## data: cindex by learner.id
## Kruskal-Wallis chi-squared = 254.5, df = 5, p-value < 2.2e-16
PMCMR::posthoc.kruskal.nemenyi.test(x=bmrkdata$cindex,g=bmrkdata$learner.id,method="Tukey")
##
## Pairwise comparisons using Tukey and Kramer (Nemenyi) test
## with Tukey-Dist approximation for independent samples
##
## data: bmrkdata$cindex and bmrkdata$learner.id
##
## surv.coxph surv.cv.CoxBoost surv.penalized.lasso
## surv.cv.CoxBoost 1.000 - -
## surv.penalized.lasso 5.7e-14 5.9e-14 -
## surv.penalized.ridge 5.7e-14 5.9e-14 1.000
## surv.glmnet 1.000 1.000 5.8e-14
## surv.randomForestSRC 0.013 0.026 2.5e-11
## surv.penalized.ridge surv.glmnet
## surv.cv.CoxBoost - -
## surv.penalized.lasso - -
## surv.penalized.ridge - -
## surv.glmnet 5.8e-14 -
## surv.randomForestSRC 2.5e-11 0.014
##
## P value adjustment method: none
These results suggest that the Cox-PH model and the CoxBoost model are the two best algorithms for predicting the mortality risk of our patients. The two models perform equivalently (Nemenyi p = 1.000) and even better than other, more sophisticated models such as randomForestSRC, penalized ridge or lasso.
Step 4: Bootstrapping a Cox-PH model
First, we fit a classic Cox-PH model in mlr (you can also do this using the survival package directly)
coxphlearner= makeLearner("surv.coxph")
coxphmod=train(coxphlearner,surv.task)
coxph=train(coxphlearner,surv.task)%>%.$learner.model
summary(coxph)
## Call:
## survival::coxph(formula = f, data = data)
##
## n= 60, number of events= 40
##
## coef exp(coef) se(coef) z Pr(>|z|)
## treatmentNo 0.80706 2.24131 0.33024 2.444 0.0145 *
## age 0.02888 1.02930 0.01198 2.411 0.0159 *
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## exp(coef) exp(-coef) lower .95 upper .95
## treatmentNo 2.241 0.4462 1.173 4.282
## age 1.029 0.9715 1.005 1.054
##
## Concordance= 0.711 (se = 0.051 )
## Rsquare= 0.213 (max possible= 0.992 )
## Likelihood ratio test= 14.4 on 2 df, p=0.0007482
## Wald test = 13.94 on 2 df, p=0.0009417
## Score (logrank) test = 14.77 on 2 df, p=0.0006191
Now we will apply the bootstrapping process to our model:
bootrs=makeResampleDesc("Bootstrap",iters=100L)
r=resample(coxphlearner,surv.task,bootrs,models=TRUE) # Bootstrapping
r$aggr
## cindex.test.mean
## 0.6949353
The averaged C-index is 0.695, based on 100 bootstrap iterations.
Confidence intervals of Hazard-ratios
resdf=r$measures.test%>%as_tibble()%>%mutate(HR_NOtreatment=rep(0,nrow(.)),HR_Age=rep(0,nrow(.)),R2=rep(0,nrow(.)),LLRtest=rep(0,nrow(.)),Waldtest=rep(0,nrow(.)),LogRank=rep(0,nrow(.)),p_notmt=rep(0,nrow(.)),p_Age=rep(0,nrow(.)))
for(i in 1:nrow(resdf)){
# Extract the fitted coxph object and its summary once per bootstrap iteration
bootfit=r$models[[i]]$learner.model
s=summary(bootfit)
resdf$HR_NOtreatment[i]=exp(bootfit$coefficients[1]) # hazard ratio: no treatment vs treatment
resdf$HR_Age[i]=exp(bootfit$coefficients[2]) # hazard ratio per year of age
resdf$p_notmt[i]=s$coefficients[1,5] # Wald p value for treatmentNo
resdf$p_Age[i]=s$coefficients[2,5] # Wald p value for age
resdf$R2[i]=s$rsq[1]
resdf$LLRtest[i]=s$logtest[3] # p value of the likelihood ratio test
resdf$Waldtest[i]=s$waldtest[3] # p value of the Wald test
resdf$LogRank[i]=s$sctest[3] # p value of the score (log-rank) test
}
HR_NOtreatment=quantile(resdf$HR_NOtreatment,c(.025, .5, .975))
HR_Age=quantile(resdf$HR_Age,c(.025, .5, .975))
rbind(HR_NOtreatment,HR_Age)%>%knitr::kable()
| | 2.5% | 50% | 97.5% |
|---|---|---|---|
| HR_NOtreatment | 0.8934282 | 2.253446 | 5.695903 |
| HR_Age | 1.0118653 | 1.029679 | 1.065440 |
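The percentile method used for this table can also be written without mlr. The sketch below is an illustration only: the data are simulated, the helper simulate_cohort is hypothetical, and the number of resamples (200) is an arbitrary choice.

```r
library(survival)
set.seed(42)

# Simulated cohort (an assumption for illustration, not the study data)
simulate_cohort <- function(n) {
  age <- round(runif(n, 45, 91))
  treatment <- factor(sample(c("Yes", "No"), n, replace = TRUE), levels = c("Yes", "No"))
  rate <- 0.02 * exp(0.03 * (age - 65) + 0.8 * (treatment == "No"))
  t_event <- rexp(n, rate)
  t_cens <- runif(n, 5, 30)
  data.frame(age, treatment, time = pmin(t_event, t_cens), status = t_event <= t_cens)
}
sim <- simulate_cohort(200)

# Refit the Cox model on B resamples drawn with replacement and keep the hazard ratios
B <- 200
hr <- t(replicate(B, {
  idx <- sample(nrow(sim), replace = TRUE)
  exp(coef(coxph(Surv(time, status) ~ treatment + age, data = sim[idx, ])))
}))

# Percentile interval: 2.5%, 50% and 97.5% quantiles of the bootstrap hazard ratios
ci <- t(apply(hr, 2, quantile, probs = c(0.025, 0.5, 0.975)))
print(ci)
```

An interval whose lower bound falls below 1, as for HR_NOtreatment in the table above, means that some resamples did not show an increased hazard for that predictor.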
Figure 5A,B: Bootstrap distributions of the p values of the likelihood ratio test, log-rank (score) test, Wald test and the predictors' Wald z-tests
Figure 5C,D: Confidence interval of Hazard ratios for Age and NOTreatment
Figure 5E: Confidence interval of Rsquared coefficient and C-index values
resdf%>%gather(LLRtest:p_Age,key="Tests",value="p_value")%>%ggplot(aes(x=p_value,fill=Tests))+geom_histogram(color="black",alpha=0.7)+scale_fill_manual(values=myfillcolors)+facet_wrap(~Tests,ncol=2,scales="free")+geom_vline(xintercept = 0.05,color="red4",linetype="dotted",size=1)
resdf%>%gather(LLRtest:p_Age,key="Tests",value="p_value")%>%ggplot(aes(x=iter,y=p_value,color=Tests))+geom_path(alpha=0.8,size=1)+scale_color_manual(values=myfillcolors)+facet_wrap(~Tests,ncol=2,scales="free")+geom_hline(yintercept = 0.05,color="red4",linetype="dotted",size=1)
resdf%>%gather(HR_NOtreatment:HR_Age,key="Predictor",value="Hazard_ratio")%>%ggplot(aes(x=Hazard_ratio,fill=Predictor))+geom_histogram(color="black",alpha=0.7)+scale_fill_manual(values=myfillcolors)+facet_wrap(~Predictor,ncol=1,scales="free")+geom_vline(xintercept=1,color="red4",linetype="dotted",size=1)
resdf%>%gather(HR_NOtreatment:HR_Age,key="Predictor",value="Hazard_ratio")%>%ggplot(aes(x=iter,y=Hazard_ratio,color=Predictor))+geom_path(alpha=0.8,size=1)+scale_color_manual(values=myfillcolors)+facet_wrap(~Predictor,ncol=1,scales="free")+geom_hline(yintercept=1,color="red4",linetype="dotted",size=1)
resdf%>%gather(cindex,R2,key="Indices",value="Value")%>%ggplot(aes(x=Value,fill=Indices))+geom_histogram(color="black",alpha=0.7)+scale_fill_manual(values=myfillcolors)+facet_grid(Indices~.,scales="free")+geom_vline(xintercept = 0.5,color="red4",linetype="dotted",size=1)
Note: Although the interpretation of a Cox-PH model looks similar to that of a logistic model (both are linear on a logarithmic scale: the log hazard and the log odds, respectively), proportional hazards regression differs from logistic regression in that it assesses a rate instead of a proportion. The Cox-PH model describes the hazard rate, that is, the number of new events per unit of time among the subjects still at risk, whereas the logistic model describes the cumulative incidence, that is, the proportion of subjects who develop the event during a given time period. Logistic regression provides the odds ratio, while proportional hazards regression (Cox model) provides the hazard ratio.
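To make the contrast concrete, the two models can be written side by side in their standard textbook form. The symbols h_0, alpha, beta and x are introduced here for illustration; the worked values reuse the coefficients of the Cox-PH summary printed in Step 4.

```latex
% Cox proportional hazards model: linear on the log-hazard scale
h(t \mid x) = h_0(t)\,\exp(\beta^\top x)
\quad\Longrightarrow\quad
\mathrm{HR}_j = \frac{h(t \mid x_j + 1)}{h(t \mid x_j)} = e^{\beta_j}

% Logistic regression: linear on the log-odds scale
\log\frac{p(x)}{1 - p(x)} = \alpha + \beta^\top x
\quad\Longrightarrow\quad
\mathrm{OR}_j = e^{\beta_j}

% Worked check against the fitted Cox-PH coefficients
e^{0.80706} \approx 2.241 \ (\text{no treatment vs. treatment}), \qquad
e^{0.02888} \approx 1.029 \ (\text{per year of age}), \qquad
e^{10 \times 0.02888} \approx 1.33 \ (\text{per 10 years of age})
```

The last value shows why a hazard ratio close to 1 for age is not negligible: under the model, a 10-year age difference corresponds to roughly a one-third increase in hazard.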
Step 5: Discovering the Boosted Cox model
As mentioned above, the CoxBoost model is considered equivalent to our Cox-PH model. We will now explore how this algorithm works on our data.
The CoxBoost package (https://cran.r-project.org/web/packages/CoxBoost/CoxBoost.pdf) provides routines for fitting Cox models by likelihood-based boosting.
The output of the CoxBoost algorithm contains the following items:
n, p = number of observations and number of covariates.
stepno = number of boosting steps
xnames = vector of length p containing the names of the covariates
coxbstlearner= makeLearner("surv.cv.CoxBoost")
coxbst=train(coxbstlearner,surv.task)
coxbstfit=coxbst$learner.model
cat("Cox-PH Model boosting was performed on",coxbstfit$n,"observations","and",coxbstfit$p,"Covariates:",coxbstfit$xnames,".Number of boosting step was",coxbstfit$stepno)
## Cox-PH Model boosting was performed on 60 observations and 3 Covariates: age treatment.Yes treatment.No .Number of boosting step was 33
scoremat contains the value of the score statistic for each of the optional covariates before each boosting step.
Figure 6: Exploring the Score matrix
scmatrx=coxbstfit$scoremat%>%as_tibble()
coefAge=coxbstfit$coefficients[,1]
coefTrYes=coxbstfit$coefficients[,2]
hr1=exp(coxbstfit$coefficients[,1])
hr2=exp(coxbstfit$coefficients[,2])
covmatrx=cbind(coefAge,coefTrYes)%>%as_tibble()%>%mutate(.,hrAge=exp(coefAge),hrTrmtYes=exp(-coefTrYes),iter=c(1:nrow(.)))
covmatrx%>%gather(coefAge:hrTrmtYes,key="Covar",value="Estimate")%>%ggplot(aes(x=iter,y=Estimate,color=Covar))+geom_path(size=1)+geom_vline(xintercept =coxbstfit$stepno,color="blue",linetype="dotted",size=1)+scale_color_manual(values=myfillcolors)+facet_wrap(~Covar,ncol=2,scales="free")
Figure 7: Exploring the linear predictor and Lambda matrices
linear.predictors is a matrix, giving the linear predictor for boosting steps 0 to stepno and every observation.
Lambda is the cumulative baseline hazard at each event time Ti. It is a matrix containing the Breslow estimate of the cumulative baseline hazard for boosting steps 0 to stepno at every event time.
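For reference, the Breslow estimate has the following standard form. The notation is introduced here for illustration: d_i is the number of events at event time T_i, R(T_i) is the set of subjects still at risk just before T_i, and beta-hat is the coefficient vector at the boosting step considered.

```latex
\hat{\Lambda}_0(t) \;=\; \sum_{i:\, T_i \le t}
\frac{d_i}{\sum_{j \in R(T_i)} \exp\!\big(\hat{\beta}^\top x_j\big)}

% Special case: if all coefficients are zero, every exp(.) term equals 1 and the
% denominator is the number at risk n_i, which gives the Nelson-Aalen estimator
\hat{\beta} = 0 \;\Longrightarrow\; \hat{\Lambda}_0(t) = \sum_{i:\, T_i \le t} \frac{d_i}{n_i}
```

Each row of the Lambda matrix can be read as this sum evaluated with the coefficients of one boosting step, so the curves change across steps only through the weights in the denominator.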
lpred=coxbstfit$linear.predictor%>%as_tibble()%>%mutate(.,iter=c(1:nrow(.)))%>%gather(V1:V60,key="Id",value="linearPredictor")
lpred%>%ggplot(aes(x=iter,y=linearPredictor,color=Id,fill=Id))+geom_path(show.legend = F)+geom_vline(xintercept=coxbstfit$stepno,color="red4",linetype="dotted",size=1)
lambdadf=coxbstfit$Lambda%>%as_tibble()%>%mutate(.,iter=c(1:nrow(.)))%>%gather(V1:V23,key="Event_time",value="Breslow_estimate")
lambdadf%>%ggplot(aes(x=iter,y=Breslow_estimate,color=Event_time))+geom_path(size=1,show.legend = F)+geom_vline(xintercept =coxbstfit$stepno,color="red4",linetype="dotted",size=1)+facet_wrap(~Event_time,ncol=5,scales="free")
predCox=(predict(coxphmod,surv.task))%>%as_tibble()%>%mutate(.,Model="CoxPH",age=df$age,treatment=df$treatment,Event=df$Event)
linearPredictor=predCox$response
Risk=predict(coxph,newdata=df,type="risk") # predict.coxph has no newtime argument; the relative risk does not depend on time
predCox=cbind(predCox,linearPredictor,Risk)
predCoxBst=(predict(coxbst,surv.task))%>%as_tibble()%>%mutate(.,Model="CoxBoost",age=df$age,treatment=df$treatment,Event=df$Event)
linearPredictor=predCoxBst$response
Risk=predict(coxbstfit,newdata=df,type="risk")
predCoxBst=cbind(predCoxBst,linearPredictor,Risk)
predlong=rbind(predCox[,c(1:9)],predCoxBst[,c(1:9)])
Figure 8: Comparing the prediction of two models
predlong%>%ggplot(aes(x=age,y=linearPredictor,color=Model,fill=Model))+geom_smooth(se=T,alpha=0.2)+scale_color_manual(values=c("red4","blue4"))+scale_fill_manual(values=c("red","blue"))
predlong%>%ggplot(aes(x=age,y=linearPredictor,color=Model,fill=Model))+geom_smooth(se=T,alpha=0.2)+facet_wrap(~Event,ncol=2,scales="free")+scale_color_manual(values=c("red4","blue4"))+scale_fill_manual(values=c("red","blue"))
predlong%>%ggplot(aes(x=as.numeric(treatment),y=linearPredictor,color=Model,fill=Model))+geom_smooth(se=T,alpha=0.2)+facet_wrap(~Event,ncol=2,scales="free")+scale_color_manual(values=c("red4","blue4"))+scale_fill_manual(values=c("red","blue"))
predlong%>%ggplot(aes(x=linearPredictor,color=Model))+geom_freqpoly(alpha=0.8,size=1)+facet_wrap(~Event,ncol=1,scales="free")+scale_color_manual(values=c("red","blue"))
As we can see, although the CoxBoost model is based on a boosting algorithm that is more complicated than the simple Cox-PH model, the predictions of the two models are equivalent. Both models provide information about the overall impact of age and treatment on the mortality risk.
predCox%>%gather(linearPredictor,Risk,key="Function",value="Predicted")%>%ggplot(aes(x=age,y=Predicted,color=treatment,fill=treatment))+geom_smooth(se=T,alpha=0.2)+facet_wrap(~Function,ncol=2,scales="free")+scale_color_manual(values=c("red4","blue4"))+scale_fill_manual(values=c("red","blue"))
predCox%>%gather(linearPredictor,Risk,key="Function",value="Predicted")%>%ggplot(aes(x=as.integer(treatment),y=Predicted,color=Function,fill=Function))+geom_smooth(se=T,alpha=0.2)+facet_wrap(~Function,ncol=2,scales="free")+scale_color_manual(values=c("red4","blue4"))+scale_fill_manual(values=c("red","blue"))
However, if we want more information about the impact of a risk factor, such as age or treatment, at a defined time point (for example, at the 10th month), the CoxBoost model can do better:
predCoxBst%>%gather(10:25,key="Time",value="Risk")%>%ggplot(aes(x=age,y=Risk,color=Time,fill=Time))+geom_smooth(se=T,alpha=0.2,show.legend = F)+facet_wrap(~Time,ncol=4,scales="free")
predCoxBst%>%gather(10:25,key="Time",value="Risk")%>%ggplot(aes(x=as.integer(treatment),y=Risk,color=Time,fill=Time))+geom_smooth(se=T,alpha=0.2,show.legend = F)+facet_wrap(~Time,ncol=4,scales="free")+scale_x_continuous(breaks=c(1,2),"Treatment (1=Yes,2=No)")
Step 6: Plotting Marginal effects
The mlr package allows us to display the marginal impact of each factor on the mortality risk as follows:
Figure 9: Marginal effect of Age as a risk factor in CoxPH and CoxBoost models
pd1=generatePartialDependenceData(coxbst,surv.task,features=c("age","treatment"),interaction=T,individual=T,gridsize = 60L)%>%.$data%>%as_tibble()%>%mutate(.,Model="CoxBoost")
pd2=generatePartialDependenceData(coxphmod,surv.task,features=c("age","treatment"),interaction=T,individual=T,gridsize = 60L)%>%.$data%>%as_tibble()%>%mutate(.,Model="CoxPH")
pd=rbind(pd1,pd2)
pd%>%ggplot(aes(x=age,y=Risk,color=Model,fill=Model))+geom_smooth(se=T,alpha=0.5)+facet_wrap(~treatment,ncol=2,scales="free")+scale_color_manual(values=c("red4","blue4"))+scale_fill_manual(values=c("red","blue"))
Figure 10: Marginal effect of Treatment as a risk factor in CoxPH and CoxBoost models
pd%>%ggplot(aes(x=treatment,y=Risk,color=Model,fill=Model))+stat_summary(fun.ymin=min,fun.ymax=max,fun.y="median",shape=21,size=1)+scale_color_manual(values=c("darkred","darkblue"))+scale_fill_manual(values=c("red","skyblue"))
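The principle behind a partial dependence curve can be sketched by hand with the survival package. This is not the mlr routine used above: the data are simulated, the helper simulate_cohort is hypothetical, and the curve is defined here as the average predicted relative risk when every patient is assigned the same age.

```r
library(survival)
set.seed(2017)

# Simulated cohort (an assumption for illustration, not the study data)
simulate_cohort <- function(n) {
  age <- round(runif(n, 45, 91))
  treatment <- factor(sample(c("Yes", "No"), n, replace = TRUE), levels = c("Yes", "No"))
  rate <- 0.02 * exp(0.03 * (age - 65) + 0.8 * (treatment == "No"))
  t_event <- rexp(n, rate)
  t_cens <- runif(n, 5, 30)
  data.frame(age, treatment, time = pmin(t_event, t_cens), status = t_event <= t_cens)
}
sim <- simulate_cohort(200)
fit <- coxph(Surv(time, status) ~ age + treatment, data = sim)

# Partial dependence of the relative risk on age:
# fix age at each grid value for all patients, predict, then average
age_grid <- seq(45, 90, by = 5)
pd_age <- sapply(age_grid, function(a) {
  modified <- sim
  modified$age <- a
  mean(predict(fit, newdata = modified, type = "risk"))
})
print(data.frame(age = age_grid, mean_risk = round(pd_age, 3)))
```

Because the Cox model is log-linear, the resulting curve is monotone in age; for a non-linear learner the same averaging procedure can reveal thresholds or plateaus.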
Conclusion: We have successfully extended the traditional proportional hazards regression (Cox-PH) analysis with machine learning techniques. A survival analysis can be handled as easily as other supervised learning tasks in mlr, by declaring two target outcomes: the time to event and the event status (as a logical value). Although the mlr package supports 15 different algorithms for survival analysis, some of them based on complicated mechanisms, classical algorithms such as the Cox-PH model are still sufficiently good for interpretive studies. The boosted Cox model is an alternative solution, and it works as well as the simple Cox-PH model.