#
# logistic_survival_method_example.R
#
# Reed Sorensen
# December 2017
#

library(survival)
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
# No `rm(list = ls())` here: every object used below is (re)assigned by this
# script, so wiping the caller's global workspace is unnecessary and
# destructive when the file is source()'d from an interactive session.

# Load data from a trial of ovarian cancer treatments (survival::ovarian)
# and add a subject identifier `id` equal to the row number.
# seq_len() is used instead of 1:nrow(.) so an empty input yields an empty
# id vector rather than c(1, 0).
df <- survival::ovarian %>%
  mutate(id = seq_len(nrow(.)))

print(df)
##    futime fustat     age resid.ds rx ecog.ps id
## 1      59      1 72.3315        2  1       1  1
## 2     115      1 74.4932        2  1       1  2
## 3     156      1 66.4658        2  1       2  3
## 4     421      0 53.3644        2  2       1  4
## 5     431      1 50.3397        2  1       1  5
## 6     448      0 56.4301        1  1       2  6
## 7     464      1 56.9370        2  2       2  7
## 8     475      1 59.8548        2  2       2  8
## 9     477      0 64.1753        2  1       1  9
## 10    563      1 55.1781        1  2       2 10
## 11    638      1 56.7562        1  1       2 11
## 12    744      0 50.1096        1  2       1 12
## 13    769      0 59.6301        2  2       2 13
## 14    770      0 57.0521        2  2       1 14
## 15    803      0 39.2712        1  1       1 15
## 16    855      0 43.1233        1  1       2 16
## 17   1040      0 38.8932        2  1       2 17
## 18   1106      0 44.6000        1  1       1 18
## 19   1129      0 53.9068        1  2       1 19
## 20   1206      0 44.2055        2  2       1 20
## 21   1227      0 59.5890        1  2       2 21
## 22    268      1 74.5041        2  1       2 22
## 23    329      1 43.1370        2  1       1 23
## 24    353      1 63.2192        1  2       2 24
## 25    365      1 64.4247        2  2       1 25
## 26    377      0 58.3096        1  2       1 26
# Width of the discrete time intervals (same units as futime)
interval <- 100

# Build the grid of consecutive time intervals [futime1, futime2]:
# round the longest follow-up time up to a multiple of `interval`, lay out
# the break points from 0 to that bound, then pair each break with the next.
df2 <- df$futime %>%
  max() %>%
  { ceiling(. / interval) * interval } %>%
  seq(0, ., by = interval) %>%
  expand.grid(futime1 = .) %>%
  mutate(futime2 = lead(futime1, 1)) %>%
  # the last break point has no successor, so it does not start an interval
  filter(!is.na(futime2))

print(df2)
##    futime1 futime2
## 1        0     100
## 2      100     200
## 3      200     300
## 4      300     400
## 5      400     500
## 6      500     600
## 7      600     700
## 8      700     800
## 9      800     900
## 10     900    1000
## 11    1000    1100
## 12    1100    1200
## 13    1200    1300
# For each subject, record their status in each time interval
# (one row per subject per interval).
# -- exclude time intervals that start at or after the event/censoring time
# -- seq_len() keeps the loop empty when df has no rows (1:nrow(df) would
#    iterate over c(1, 0))

df3 <- bind_rows(lapply(seq_len(nrow(df)), function(i) {

  # replicate subject i once per interval and attach the interval bounds
  df[rep(i, nrow(df2)), ] %>%
    select(id, futime, fustat, age, rx, ecog_ps = ecog.ps) %>%
    cbind(df2, .) %>%
    mutate(
      futime_mid = futime1 + interval / 2,
      # status is 0 in every interval that ends before the subject's last
      # follow-up time; only the final interval keeps the recorded fustat
      fustat = ifelse(futime > futime2, 0, fustat)
    ) %>%
    # keep only intervals the subject actually entered
    filter(futime > futime1)

}))

print(head(df3, n = 30))
##    futime1 futime2 id futime fustat     age rx ecog_ps futime_mid
## 1        0     100  1     59      1 72.3315  1       1         50
## 2        0     100  2    115      0 74.4932  1       1         50
## 3      100     200  2    115      1 74.4932  1       1        150
## 4        0     100  3    156      0 66.4658  1       2         50
## 5      100     200  3    156      1 66.4658  1       2        150
## 6        0     100  4    421      0 53.3644  2       1         50
## 7      100     200  4    421      0 53.3644  2       1        150
## 8      200     300  4    421      0 53.3644  2       1        250
## 9      300     400  4    421      0 53.3644  2       1        350
## 10     400     500  4    421      0 53.3644  2       1        450
## 11       0     100  5    431      0 50.3397  1       1         50
## 12     100     200  5    431      0 50.3397  1       1        150
## 13     200     300  5    431      0 50.3397  1       1        250
## 14     300     400  5    431      0 50.3397  1       1        350
## 15     400     500  5    431      1 50.3397  1       1        450
## 16       0     100  6    448      0 56.4301  1       2         50
## 17     100     200  6    448      0 56.4301  1       2        150
## 18     200     300  6    448      0 56.4301  1       2        250
## 19     300     400  6    448      0 56.4301  1       2        350
## 20     400     500  6    448      0 56.4301  1       2        450
## 21       0     100  7    464      0 56.9370  2       2         50
## 22     100     200  7    464      0 56.9370  2       2        150
## 23     200     300  7    464      0 56.9370  2       2        250
## 24     300     400  7    464      0 56.9370  2       2        350
## 25     400     500  7    464      1 56.9370  2       2        450
## 26       0     100  8    475      0 59.8548  2       2         50
## 27     100     200  8    475      0 59.8548  2       2        150
## 28     200     300  8    475      0 59.8548  2       2        250
## 29     300     400  8    475      0 59.8548  2       2        350
## 30     400     500  8    475      1 59.8548  2       2        450
# normal Cox-PH regression
# -- fitted on the original data frame (one row per subject), with
#    follow-up time `futime` and event indicator `fustat`
# -- exp(coef) from this model is a hazard ratio
fit1 <- coxph(Surv(futime, fustat) ~ age + rx, data = df)
summary(fit1)
## Call:
## coxph(formula = Surv(futime, fustat) ~ age + rx, data = df)
## 
##   n= 26, number of events= 12 
## 
##         coef exp(coef) se(coef)      z Pr(>|z|)   
## age  0.14733   1.15873  0.04615  3.193  0.00141 **
## rx  -0.80397   0.44755  0.63205 -1.272  0.20337   
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
##     exp(coef) exp(-coef) lower .95 upper .95
## age    1.1587      0.863    1.0585     1.268
## rx     0.4475      2.234    0.1297     1.545
## 
## Concordance= 0.798  (se = 0.091 )
## Rsquare= 0.457   (max possible= 0.932 )
## Likelihood ratio test= 15.89  on 2 df,   p=0.0003551
## Wald test            = 13.47  on 2 df,   p=0.00119
## Score (logrank) test = 18.56  on 2 df,   p=9.341e-05
# logistic version, time dummies
# -- fitted on the expanded data (df3: one row per subject per interval),
#    with one dummy per interval via as.factor(futime_mid)
# -- no events occur after futime = 638 in this data set, so the dummies for
#    intervals from 750 onward have no events to estimate from; expect very
#    large negative estimates with huge standard errors for those terms
fit3 <- glm(fustat ~ age + rx + as.factor(futime_mid), family = "binomial", data = df3)
summary(fit3)
## 
## Call:
## glm(formula = fustat ~ age + rx + as.factor(futime_mid), family = "binomial", 
##     data = df3)
## 
## Deviance Residuals: 
##      Min        1Q    Median        3Q       Max  
## -1.48233  -0.32303  -0.16494  -0.00015   2.66590  
## 
## Coefficients:
##                             Estimate Std. Error z value Pr(>|z|)   
## (Intercept)                -13.40287    4.10284  -3.267  0.00109 **
## age                          0.18097    0.05528   3.274  0.00106 **
## rx                          -1.07715    0.72976  -1.476  0.13994   
## as.factor(futime_mid)150     1.28263    1.44784   0.886  0.37568   
## as.factor(futime_mid)250     1.03090    1.66123   0.621  0.53489   
## as.factor(futime_mid)350     3.14903    1.50998   2.085  0.03703 * 
## as.factor(futime_mid)450     3.55940    1.55300   2.292  0.02191 * 
## as.factor(futime_mid)550     3.16643    1.80476   1.754  0.07935 . 
## as.factor(futime_mid)650     3.26987    1.80848   1.808  0.07059 . 
## as.factor(futime_mid)750   -13.69619 3219.69831  -0.004  0.99661   
## as.factor(futime_mid)850   -13.37454 3831.95231  -0.003  0.99722   
## as.factor(futime_mid)950   -13.58279 4525.88922  -0.003  0.99761   
## as.factor(futime_mid)1050  -13.58279 4525.88922  -0.003  0.99761   
## as.factor(futime_mid)1150  -13.78092 5103.45002  -0.003  0.99785   
## as.factor(futime_mid)1250  -14.00089 7024.23929  -0.002  0.99841   
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 86.754  on 169  degrees of freedom
## Residual deviance: 59.827  on 155  degrees of freedom
## AIC: 89.827
## 
## Number of Fisher Scoring iterations: 18
# compare HR (Cox) and OR (logistic) for age
# -- exp(coef) from coxph is a hazard ratio; from the logistic model it is
#    an odds ratio
exp(coef(fit1)["age"]) # HR
##      age 
## 1.158732
exp(coef(fit3)["age"]) # OR
##     age 
## 1.19838
# compare HR and OR for treatment
exp(coef(fit1)["rx"])
##        rx 
## 0.4475473
exp(coef(fit3)["rx"])
##        rx 
## 0.3405658
#####
# Get the predicted probability of the event (fustat == 1; death according
# to the survival::ovarian documentation) for each subject at their time of
# last observation
#

# From the help docs for predict.coxph:
# "The survival probability for a subject is equal to exp(-expected)."
# -- so the failure probability is 1 - exp(-expected)
options(scipen = 999)
df[["prob_coxph"]] <- 1 - exp(-predict(fit1, newdata = df, type = "expected"))
df3[["prob_timedummies"]] <- predict(fit3, newdata = df3, type = "response")

# Per subject: the product of the per-interval conditional survival
# probabilities (1 - p) is the cumulative survival probability; its
# complement is the cumulative failure probability
df4 <- summarize(
  group_by(df3, id),
  prob_timedummies = 1 - prod(1 - prob_timedummies)
)

# One row per subject with both predictions, rounded for display
df5 <- left_join(df, df4, by = "id") %>%
  select(id, futime, fustat, age, rx, prob_coxph, prob_timedummies) %>%
  mutate(
    prob_coxph = round(prob_coxph, digits = 3),
    prob_timedummies = round(prob_timedummies, digits = 3)
  )

# compare the predicted probabilities
# according to Cox PH and logistic models
print(df5)
##    id futime fustat     age rx prob_coxph prob_timedummies
## 1   1     59      1 72.3315  1      0.166            0.199
## 2   2    115      1 74.4932  1      0.426            0.686
## 3   3    156      1 66.4658  1      0.263            0.298
## 4   4    421      0 53.3644  2      0.090            0.160
## 5   5    431      1 50.3397  1      0.161            0.251
## 6   6    448      0 56.4301  1      0.349            0.544
## 7   7    464      1 56.9370  2      0.231            0.275
## 8   8    475      1 59.8548  2      0.389            0.408
## 9   9    477      0 64.1753  1      0.875            0.903
## 10 10    563      1 55.1781  2      0.283            0.277
## 11 11    638      1 56.7562  1      0.699            0.767
## 12 12    744      0 50.1096  2      0.183            0.159
## 13 13    769      0 59.6301  2      0.560            0.590
## 14 14    770      0 57.0521  2      0.429            0.440
## 15 15    803      0 39.2712  1      0.087            0.070
## 16 16    855      0 43.1233  1      0.149            0.134
## 17 17   1040      0 38.8932  1      0.083            0.065
## 18 18   1106      0 44.6000  1      0.181            0.170
## 19 19   1129      0 53.9068  2      0.297            0.286
## 20 20   1206      0 44.2055  2      0.081            0.058
## 21 21   1227      0 59.5890  2      0.558            0.587
## 22 22    268      1 74.5041  1      0.778            0.846
## 23 23    329      1 43.1370  1      0.025            0.038
## 24 24    353      1 63.2192  2      0.263            0.356
## 25 25    365      1 64.4247  2      0.382            0.413
## 26 26    377      0 58.3096  2      0.177            0.177
# Scatter plot of the two sets of predictions: Cox PH on the x-axis,
# logistic (time dummies) on the y-axis
plot(prob_timedummies ~ prob_coxph, data = df5)

# Notes:
# convert probability to rate: -ln(1-p) / t
# convert rate to probability: 1 - exp(-rt)
#