1 Setup

To try some predictive modelling, we need to combine the output of CATIE TRS : Survival Analysis - v3.0 with some of the source CATIE data for neurocognitive testing, and the “raw” data for PANSS and baseline demographics.

This is achieved using a script ML_pre-process.R that I am happy to provide on request – it is not included here because it makes use of source data from the CATIE trial that I do not have permission to publish (it is available from the NIMH data archive on application) and therefore can’t be executed “stand alone”. Instead, only the derived variables actually required for this analysis are loaded from ML_data_whole_population.csv, so this notebook is “self-contained” and can be reproduced.

Note, we are not making strong claims for this attempt – it represents initial work on predictive modelling in this data set, where we have established that, even for classical inferential analyses, there’s not a great deal going on. My view is that, as the inferential results show nothing (with the GLMs having the benefit of access to the entire data set), I would be surprised if any variable (or combination of variables) proved to have value in a predictive model. Nevertheless, boosted tree algorithms have (in my experience) proved to be a good method to try on difficult datasets. We also tried regularised GLMs, and SMOTE-like oversampling to compensate for the rarity of TRS cases, and obtained similarly poor results, so this notebook really just illustrates what was tried, for completeness. If we had obtained more robust results, the following would need validation beyond the simple permutation test (under the null) to be more informative. Finally, this notebook is written in a fairly tutorial tone / level of detail, to make the logic of the analyses as transparent as possible.

rm( list = ls() )

require( ggplot2 )
require( gridExtra )
require( xgboost )
require( reshape2 )
require( kableExtra )
require( RColorBrewer )
require( caret )

# load data
d <- read.csv("ML_data_whole_population.csv")

# we have to use only complete cases
d <- d[ which( complete.cases(d) ), ]
N <- nrow( d )

# relabel classes 0 = NTR, 1 = TRS
d$status.TRS <- ifelse( d$status.TRS == 2, 1, 0)

# re-order so that status.TRS is the last column
temp <- d[ , -which( names(d) %in% c("status.TRS") ) ]
d <- cbind( temp, d$status.TRS )
names(d)[length(d)] <- "status.TRS"

2 Initial Tests

The data in d contains the “whole population” dataset. Because of pseudo-randomisation, we’ll get different values each time the following code is run.

To begin with, we’ll use a random half split of the data:

# independent (predictor) variables
x.vars <- names( d )[1:(length(names(d))-1)]
# dependent (outcome, response) variables
y.var  <- names( d )[length(names(d))]

folds <- createFolds(d$status.TRS, k = 2)

# divide data into matrix of X and vector of Y for convenience
# and convert to matrices for xgboost
trainSet.X <- as.matrix( d[ folds$Fold1, x.vars  ] )
trainSet.Y <- as.matrix( d[ folds$Fold1, y.var  ] )

validSet.X <- as.matrix( d[ folds$Fold2, x.vars ] )
validSet.Y <- as.matrix( d[ folds$Fold2, y.var ] )

Look at the class counts in each split, just to check:

table( trainSet.Y )
## trainSet.Y
##   0   1 
## 518  19
table( validSet.Y )
## validSet.Y
##   0   1 
## 502  35

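… the two halves are the same size (537 rows each), although the rare TRS cases are not evenly divided (19 in the training half, 35 in the validation half).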
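If a more even division of TRS cases were wanted, a split stratified on the class label is one option. The following is a minimal base-R sketch, not the method used in this notebook; stratifiedHalfSplit and y.toy are hypothetical names, and the toy outcome simply reuses the class counts from the tables above.

```r
# Hypothetical helper (not part of the original notebook): split row indices
# into two halves, sampling separately within each class so that both halves
# receive (almost) the same number of rare positive cases.
stratifiedHalfSplit <- function( y, seed = NULL ) {
  if ( !is.null( seed ) ) set.seed( seed )   # optional, for reproducibility
  inTrain <- logical( length( y ) )
  for ( cls in unique( y ) ) {
    idx <- which( y == cls )
    # shuffle the row indices of this class, then keep the first half
    idx <- idx[ sample.int( length( idx ) ) ]
    inTrain[ idx[ seq_len( floor( length( idx ) / 2 ) ) ] ] <- TRUE
  }
  list( train = which( inTrain ), valid = which( !inTrain ) )
}

# toy outcome with the same class counts as the tables above (1020 NTR, 54 TRS)
y.toy <- c( rep( 0, 1020 ), rep( 1, 54 ) )
split.toy <- stratifiedHalfSplit( y.toy, seed = 1 )
table( y.toy[ split.toy$train ] )
table( y.toy[ split.toy$valid ] )
```

Passing a seed also makes the split repeatable, which may help when comparing runs.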

Next, fit the xgboost model, using tree boosters and two boosting rounds, with the rest of the parameters left at their defaults.

fit <- xgboost(data = trainSet.X, 
               label = trainSet.Y, 
               eta = 0.3, nthread = 2, nrounds = 2, objective = "binary:logistic",
               verbose = 2)

Predict the classifications of the validation set, using model fitted on the training set:

pred <- predict(fit, validSet.X)

decRule <- function( x ) { return( ifelse( x > 0.5, 1, 0 )) }
y_pred  <- decRule( pred )

And compute a confusion matrix:

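# recent versions of caret require factors with identical levels
cm <- confusionMatrix( factor( y_pred, levels = c(0,1) ),
                       factor( as.vector( validSet.Y ), levels = c(0,1) ),
                       positive = "1" )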

And we see that:

(cm$table)
##           Reference
## Prediction   0   1
##          0 494  35
##          1   8   0
  1. True positives (reference group 1, correctly predicted as group 1) = 0
  2. True negatives (reference = group 0, correctly predicted as group 0) = 494
  3. False positives (for group 0, incorrectly predicted as group 1) = 8
  4. False negatives (for group 1, incorrectly predicted as group 0) = 35

This corresponds to sensitivity and specificity:

  • Sensitivity = 0
  • Specificity = 0.98
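The two figures follow directly from the four counts in the confusion matrix. As a small worked check (the variable names are illustrative only):

```r
# counts copied from the confusion matrix above
TP <- 0; TN <- 494; FP <- 8; FN <- 35

sensitivity <- TP / ( TP + FN )   # fraction of true TRS cases detected
specificity <- TN / ( TN + FP )   # fraction of true NTR cases detected
accuracy    <- ( TP + TN ) / ( TP + TN + FP + FN )

round( c( sensitivity = sensitivity,
          specificity = specificity,
          accuracy    = accuracy ), 3 )
```

The accuracy is shown only for contrast: it stays high even though no TRS case is detected.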

3 Transparency of the model

Unlike GLMs, the results of tree-based models are harder to interpret from the perspective of classical inference. So we’ll exploit the importance and tree-plotting functions in the xgboost package:

importance_matrix <- xgb.importance(model = fit)
print(importance_matrix)
##             Feature        Gain      Cover Frequency
## 1:                N 0.408286536 0.40355727 0.2222222
## 2:          speed_s 0.392070541 0.00769494 0.2222222
## 3:           pos_p3 0.152688728 0.01299504 0.1111111
## 4:                G 0.018429595 0.18842870 0.1111111
## 5:       B.personal 0.017830713 0.19023293 0.1111111
## 6:         A.social 0.005878223 0.01046822 0.1111111
## 7: yrsFrstAntiPsyRx 0.004815664 0.18662290 0.1111111
xgb.plot.importance(importance_matrix = importance_matrix)

require(DiagrammeR)
## Loading required package: DiagrammeR
xgb.plot.tree(model = fit)

4 Cross-Validation and Reliability

There’s some debate about k-fold cross-validation for validating a classifier algorithm – with the much-quoted \(k=10\) (from Kohavi’s 1995 paper) often used as a good experimental compromise between bias and variance. However, as we don’t have that much data (\(N =\) 1074, and TRS cases are rare in this population), a half-split (\(k=2\)) repeated many times should be reasonable, provided we run enough replications.

First, a function to setup a half-split of the data and return the training and test sets.

setupSplit <- function( d ) {
  
  x.vars <- names( d )[1:(length(names(d))-1)]
  # dependent (outcome, response) variables
  y.var  <- names( d )[length(names(d))]
  
  folds <- createFolds(d$status.TRS, k = 2)
  
  # divide data into matrix of X and vector of Y for convenience
  # and convert to matrices for xgboost
  trainSet.X <- as.matrix( d[ folds$Fold1, x.vars  ] )
  trainSet.Y <- as.matrix( d[ folds$Fold1, y.var  ] )
  
  validSet.X <- as.matrix( d[ folds$Fold2, x.vars ] )
  validSet.Y <- as.matrix( d[ folds$Fold2, y.var ] )

  return( list( trainSet.X = trainSet.X, trainSet.Y = trainSet.Y, 
                validSet.X = validSet.X, validSet.Y = validSet.Y) )
}

Setup 10000 replications of the “split, train, validate” routine and their results:

# storage for results
N.repl <- 10000

resultsTab <- data.frame( sens = rep(0,N.repl),
                          spec = rep(0,N.repl),
                          error = rep(0,N.repl),   ## raw error
                          top1 = rep("-",N.repl),  ## store top 3 variables
                          top2 = rep("-",N.repl),
                          top3 = rep("-",N.repl), stringsAsFactors = FALSE )

Define a function to execute a single run:

# function to run a single split, train and test/validate
ReplicateTrainValidate <- function( thisSplit ) {

  # a row to store and return results for a bigger table
  resultRow <- data.frame( sens = 0,
                           spec = 0,
                           error = 0,
                           top1 = "-",
                           top2 = "-",
                           top3 = "-", stringsAsFactors = FALSE )

  # 1. fit the training data
  fit <- xgboost(data = thisSplit$trainSet.X, 
                 label = thisSplit$trainSet.Y, 
                 max_depth = 2, eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic",
                 verbose = 0)
  
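  # 2. test on the validation set
  y_pred <- decRule( predict(fit, thisSplit$validSet.X) )

  # 3. pull out the sens and spec values
  #    (recent versions of caret require factors with identical levels)
  suppressWarnings(
   cm <- confusionMatrix( factor( y_pred, levels = c(0,1) ),
                          factor( as.vector( thisSplit$validSet.Y ), levels = c(0,1) ),
                          positive = "1" )
  )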

  # Build a "row" of results to return ...
  # 4. the "raw" misclassification rate (fraction of incorrect)
  resultRow$error <- length( which( as.numeric( thisSplit$validSet.Y ) != y_pred ) ) /
                                     length( thisSplit$validSet.Y )

  # 5. sensitivity and specificity
  resultRow$sens <- cm$byClass["Sensitivity"]
  resultRow$spec <- cm$byClass["Specificity"]

  # 6. fetch the top 3 predictor variables
  importance_matrix <- xgb.importance(model = fit)
  # ... and store away for later
  resultRow$top1 <- as.character( importance_matrix[1,1] )
  resultRow$top2 <- as.character( importance_matrix[2,1] )
  resultRow$top3 <- as.character( importance_matrix[3,1] )

  return( resultRow )
}

And this runs N.repl = 10000 replications (warning: this takes some time to execute - around 3 minutes on an i7 desktop)

startT <- Sys.time()
pb <- txtProgressBar(min = 0, max = N.repl, initial = 0, style = 1)
for ( i in 1:N.repl ) {
  resultsTab[i,] <- ReplicateTrainValidate( setupSplit( d ) )
  setTxtProgressBar(pb, i)
}
## ===========================================================================
close(pb)
print( Sys.time() - startT )
## Time difference of 3.037827 mins

Having trained 10000 classifiers (to rule out chance performance on a “lucky” split) we examine the average performance to see if there’s anything unusual; part of this assessment will be to see which predictor variables are being used to drive classification in a majority of the 10000 replications.

Start by looking at the average behaviour of the sensitivity and specificity (where we use the median and IQR as measures of central tendency and dispersion):

SensSpecPlots <- function( resultsTab ) {

  median.sens <- median( resultsTab$sens )
  iqr.sens    <- IQR( resultsTab$sens )

  p.avSens <- ggplot( resultsTab, aes( x = sens ) ) +
    geom_histogram( bins = 15, fill = "#99CCFF" ) +
    xlab("\nSensitivity") +
    ylab("Frequency\n") +
    xlim( 0, 1.2 ) +
    geom_vline( xintercept = median.sens, size = 1, color = "red") +
    annotate("segment", x = median.sens - iqr.sens, xend = median.sens + iqr.sens,
                 y = 1000, yend = 1000, size = 1, color = "red") +
    theme_classic()

  median.spec <- median( resultsTab$spec )
  iqr.spec    <- IQR( resultsTab$spec )

  p.avSpec <- ggplot( resultsTab, aes( x = spec ) ) +
    geom_histogram( bins = 15, fill = "#FF66CC" ) +
    xlab("\nSpecificity") +
    ylab("Frequency\n") +
    xlim( 0, 1.2 ) +
    geom_vline( xintercept = median.spec, size = 1, color = "red") +
    annotate("segment", x = median.spec - iqr.spec, xend = median.spec + iqr.spec,
                 y = 1000, yend = 1000, size = 1, color = "red") +
    theme_classic()

  median.err <- median( resultsTab$error )
  iqr.err  <- IQR( resultsTab$error )

  p.avError <- ggplot( resultsTab, aes( x = error ) ) +
    geom_histogram( bins = 10, fill = "#669933" ) +
    xlab("\nMisclassification Error") +
    ylab("Frequency\n") +
    #xlim( 0, 6 ) +
    geom_vline( xintercept = median.err, size = 1, color = "red") +
    annotate("segment", x = median.err - iqr.err, xend = median.err + iqr.err,
                 y = 1000, yend = 1000, size = 1, color = "red") +
    theme_classic()

  grid.arrange( p.avSens, p.avSpec, p.avError, ncol = 3)

  # return the medians for later use
  return( c( "median.sens" = median.sens, "iqr.sens" = iqr.sens,
             "median.spec" = median.spec, "iqr.spec" = iqr.spec,
             "median.err"  = median.err,  "iqr.err"  = iqr.err) )

}

# Plot the sensitivity, specificity and misclassification errors
sensSpec <- SensSpecPlots( resultsTab )

So on average, we are obtaining out of sample performances (that is, on the validation sets) of:

  • median sensitivity = 0 with +/- IQR = 0
  • median specificity = 1 with +/- IQR = 0.002
  • median misclassification error = 0.052 with +/- IQR = 0.009
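For context, these figures can be set against a trivial rule that labels every participant as NTR. A rough sketch, using the class counts from the two tables in Section 2 (the variable names are illustrative, and the whole-sample prevalence is used as an approximation to any one validation half):

```r
# class counts from the training and validation tables in Section 2
n.NTR <- 518 + 502
n.TRS <- 19 + 35

# a rule that always predicts NTR makes no false positives
# and misses every TRS case
baseline <- c( sens  = 0,
               spec  = 1,
               error = n.TRS / ( n.NTR + n.TRS ) )
round( baseline, 3 )
```

The baseline error is simply the prevalence of TRS, so a median misclassification error near that value carries no evidence that the classifiers have learned anything.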

Let’s look at – on average – which predictor variables are doing the work. We have trained 10000 classifiers on different (pseudo-randomised) training-validation splits of the data.

So we can interrogate each of these classifiers to find out how frequently each predictor is the 1st, 2nd, 3rd (and so on) most important predictor in these 10000 replications. This gives us some idea of the stability of the features (predictors) in classifier performance when the training and validation data varies (a bit like a sensitivity analysis).

plotPredictors <- function( predsList, thisTitle ){

  # tabulate how often each predictor appears at this rank
  thisTop <- table( predsList )
  thisTop <- rev( thisTop[ order( thisTop )] )
  
  # build a dataframe to plot
  df <- data.frame( predictor = factor( names( thisTop ), levels = names( thisTop ), ordered = TRUE ),
                     freq      = as.numeric( thisTop ) ) 

  p.top <- ggplot( df, aes( x = predictor, y = log10(freq) ) ) +
          geom_bar( stat = "identity", fill = "#CC6600" ) +
          xlab( "\nPredictor" ) +
          ylab( "Log10 Frequency\n" ) +
          ylim( 0, log10(10000) ) +
          theme_classic() +
          theme( axis.text.x = element_text(angle = 90, hjust = 1 ) ) +
          ggtitle( thisTitle )
  return( p.top )

}

p.top1 <- plotPredictors( resultsTab$top1, "1st")
p.top2 <- plotPredictors( resultsTab$top2, "2nd")
p.top3 <- plotPredictors( resultsTab$top3, "3rd")

Examining the frequency of the “top” predictor – that is, the predictor chosen as the most important in the boosting tree algorithms over all 10000 replications – we see that:

(p.top1)

Notice how the neurocognitive tests are chosen more often, but this is not reliable – the graph should show a “sharp” drop after the first predictor (because if it were reliably behaving as the top-performing predictor variable, all others would be rarely selected). Similarly, looking at the 2nd most frequently chosen predictors:

(p.top2)

Similar pattern – no one predictor is behaving as the consistent second most-important predictor.

These plots tell the story: no predictor variables are consistently selected, and the sensitivity / specificity plots are clearly degenerate. For the rare cases of TRS, the sensitivity plot tells us the classifiers essentially never spot them. The misclassification graph looks (superficially) impressive, but a classifier that is very specific in the context of rare positive cases also needs very good sensitivity; otherwise, as in this case, it is just very good at spotting non-TRS cases, which is not all that surprising given that they outnumber TRS cases by roughly 19 to 1 (1020 vs 54 in the split shown earlier).

5 Classifier Performance Compared to a Null Distribution

For completeness, we’ll run a similar analysis to generate a distribution of sensitivity, specificity and misclassification for 10000 classifiers (trained/validated on the same \(k\)=2 split regime) where the classifications (TRS, non-TRS) are randomly permuted – so we can compute an empirical p-value for the actual classifiers performing better than chance. Again, this takes some time.

5.1 Permutation Testing

# K = number of randomised data sets to use in estimating p-value
K <- 10000

# storage for results
nullResultsTab <- data.frame( sens = rep(0,K),
                              spec = rep(0,K),
                              error = rep(0,K),
                              top1 = rep("-",K),  ## store top 3 variables 
                              top2 = rep("-",K),
                              top3 = rep("-",K), stringsAsFactors = FALSE )

startT <- Sys.time()
pb <- txtProgressBar(min = 0, max = K, initial = 0, style = 1)
for ( i in 1:K ) {
  # randomise the class labels (the outcome column, status.TRS)
  d.rand <- d
  d.rand$status.TRS <- sample( d.rand$status.TRS )
  nullResultsTab[i,] <- ReplicateTrainValidate( setupSplit( d.rand ) )
  setTxtProgressBar(pb, i)
}
## ===========================================================================
close(pb)
print( Sys.time() - startT)
## Time difference of 3.096146 mins

Now we have 10000 estimates of the error (and of the feature/predictor structure of each classifier) under the null hypothesis that the classifier is learning nothing from the data (because the class labels are randomised).
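Permuting the class labels keeps the class counts intact but breaks any link between the outcome and the predictors. A minimal sketch on toy data (the data frame toy, its predictor x and the seed are hypothetical, chosen only for illustration):

```r
set.seed( 42 )   # arbitrary seed so the sketch is repeatable

# toy data: one predictor that perfectly tracks the (rare) positive class
toy <- data.frame( x          = c( rep( 0, 95 ), rep( 1, 5 ) ),
                   status.TRS = c( rep( 0, 95 ), rep( 1, 5 ) ) )

toy.rand <- toy
# shuffle only the outcome column; the predictors stay where they are
toy.rand$status.TRS <- sample( toy.rand$status.TRS )

table( toy$status.TRS )        # class counts before permutation
table( toy.rand$status.TRS )   # identical counts after permutation

cor( toy$x, toy$status.TRS )             # 1 by construction
cor( toy.rand$x, toy.rand$status.TRS )   # typically much smaller
```

Note that the column being shuffled has to be the one the classifier uses as its label; shuffling any other column would leave the outcome untouched.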

For each of the original classifiers (trained on the real class labels), we count the number of random classifiers whose misclassification error is less than or equal to that classifier’s error (i.e. the error on the randomised data is at least as good as on the real data). With one added to the numerator and denominator, this gives the empirical \(p\)-value for the classifier performing differently to random: \[ p = \frac{\# \left( \text{Err}(F,D_{V}^{*}) \leq \text{Err}(F,D_{V}) \right)+1}{K+1} \] where \(K\) is the number of randomised data sets, \(D_{V}^{*}\) denotes the validation set from data with a random permutation of the class labels, and \(D_{V}\) denotes the validation data for one of the “real” train/validation splits. \(\text{Err}(F,D_{V}^{*})\) is then the misclassification error of a random classifier, and \(\text{Err}(F,D_{V})\) is the misclassification error of an actual classifier; in all cases, note that the validation set performance is used (not the training set).

We implement this, and then compute the average of all K = 10000 p-values to arrive at a final hypothesis test. The misclassification error is a coarse measure of performance, because it doesn’t distinguish false positives from false negatives, so in a similar way, we’ll compute the empirical p-values for sensitivity and specificity:

pTab   <- rep(NA,K)
pSens  <- rep(NA,K)
pSpec  <- rep(NA,K)
for ( i in 1:nrow(resultsTab) ) {
  pTab[i]  <- (length( which( nullResultsTab$error <= resultsTab$error[i] ) ) + 1 ) / ( K + 1 )
  pSens[i] <- (length( which( nullResultsTab$sens  >= resultsTab$sens[i] ) ) + 1 ) / ( K + 1 )
  pSpec[i] <- (length( which( nullResultsTab$spec  >= resultsTab$spec[i] ) ) + 1 ) / ( K + 1 )
}

final.pValueErr  <- mean( pTab )
final.pValueSens <- mean( pSens )
final.pValueSpec <- mean( pSpec )
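The counting rule used in the loop above can also be written as a small helper, which makes the “+1” correction and the direction of the comparison explicit. This is a sketch only; empiricalP, its lowerIsBetter argument and the toy values are hypothetical and not part of the analysis.

```r
# Hypothetical helper: empirical p-value with the +1 correction.
# 'null' holds the statistic for each permuted-label classifier,
# 'observed' is the statistic for one real classifier.
empiricalP <- function( observed, null, lowerIsBetter = TRUE ) {
  # which null classifiers did at least as well as the observed one?
  asGood <- if ( lowerIsBetter ) null <= observed else null >= observed
  ( sum( asGood ) + 1 ) / ( length( null ) + 1 )
}

null.err <- c( 0.04, 0.05, 0.05, 0.06, 0.07 )   # toy null errors

empiricalP( 0.05, null.err )   # (3 + 1) / (5 + 1)
empiricalP( 0.01, null.err )   # (0 + 1) / (5 + 1)
```

For error, lower is better, so the comparison is `<=`; for sensitivity and specificity, higher is better, so `lowerIsBetter = FALSE` gives the `>=` comparison.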

The results:

  • On misclassification error, we arrive at an average \(p\)-value of 0.5451

And for sensitivity and specificity:

  • Average \(p\)-value for sensitivity = 0.9654
  • Average \(p\)-value for specificity = 0.7135

So we can conclude that the real classification algorithm is performing no better than random (i.e. the null distribution of random classifiers). To see why, we can plot the distributions of sensitivity and specificity for the real classifiers and the random classifiers:

nullResultsTab$lbl <- "Null"
resultsTab$lbl  <- "Actual" 

allResults <- rbind( resultsTab, nullResultsTab )

allResults$lbl <- factor( allResults$lbl )

save( allResults, nullResultsTab, resultsTab, file = "distributionClassificationPerformance.RData")

sens.plot <- ggplot( allResults, aes( x = sens, fill = lbl ) ) +
  geom_histogram( position = "identity", bins = 20, alpha = 0.5 ) +
  xlab("\nSensitivity") +
  ylab("Frequency\n") +
  theme_classic() +
  theme(legend.position="none")


spec.plot <- ggplot( allResults, aes( x = spec, fill = lbl ) ) +
  geom_histogram( position = "identity", bins = 20, alpha = 0.3 ) +
  xlab("\nSpecificity") +
  ylab("Frequency\n") +
  theme_classic() +
  theme(legend.position=c(0,1))

grid.arrange( sens.plot, spec.plot, ncol = 2 )

The distributions almost completely overlap. It is informative to look at the random classifiers’ use of the predictors, to understand whether the structure shown for the real classifiers (where PANSS_pos was the most important predictor) is different from that of the random classifiers:

p.top1 <- plotPredictors( nullResultsTab$top1, "1st")
p.top2 <- plotPredictors( nullResultsTab$top2, "2nd")
p.top3 <- plotPredictors( nullResultsTab$top3, "3rd")
  
grid.arrange(p.top1, p.top2, p.top3, ncol = 3)

These distributions of the 1st, 2nd and 3rd most important predictors/features are similar to those for the actual classifiers, as we would predict if there are no concrete relationships between the predictors and the class labels that can be exploited.