# Setup ----
# Run this script in a fresh R session. It deliberately does not call
# rm(list = ls()): that wipes the caller's workspace without resetting loaded
# packages or options, so it provides no real isolation.

library(survival)
library(rstpm2)
## Loading required package: splines
## 
## Attaching package: 'rstpm2'
## The following object is masked from 'package:survival':
## 
##     colon
# Load data
# Take a fresh copy straight from the package. data(lung) fails with
# "data set 'lung' not found", and starting from the package copy keeps the
# recode below correct when the script is re-run in the same session.
lung <- survival::lung
# Recode status: 1 = event (death), 0 = censored
# (the packaged data codes death as 2)
lung$status <- as.numeric(lung$status == 2)

# Inspect data: show the first six rows
head(lung, n = 6L)
##   inst time status age sex ph.ecog ph.karno pat.karno meal.cal wt.loss
## 1    3  306      1  74   1       1       90       100     1175      NA
## 2    3  455      1  68   1       0       90        90     1225      15
## 3    3 1010      0  56   1       0       90        90       NA      15
## 4    5  210      1  57   1       1       90        60     1150      11
## 5    1  883      1  60   1       0      100        90       NA       0
## 6   12 1022      0  74   1       1       50        80      513       0
# Column types and a short preview of every variable
str(object = lung)
## 'data.frame':    228 obs. of  10 variables:
##  $ inst     : num  3 3 3 5 1 12 7 11 1 7 ...
##  $ time     : num  306 455 1010 210 883 ...
##  $ status   : num  1 1 0 1 1 0 1 1 1 1 ...
##  $ age      : num  74 68 56 57 60 74 68 71 53 61 ...
##  $ sex      : num  1 1 1 1 1 1 2 2 1 1 ...
##  $ ph.ecog  : num  1 0 0 1 0 1 2 2 1 2 ...
##  $ ph.karno : num  90 90 90 90 100 50 70 60 70 70 ...
##  $ pat.karno: num  100 90 90 60 90 80 60 80 80 70 ...
##  $ meal.cal : num  1175 1225 NA 1150 NA ...
##  $ wt.loss  : num  NA 15 15 11 0 0 10 1 16 34 ...
# Per-column quartiles, means and NA counts
summary(object = lung)
##       inst            time            status            age       
##  Min.   : 1.00   Min.   :   5.0   Min.   :0.0000   Min.   :39.00  
##  1st Qu.: 3.00   1st Qu.: 166.8   1st Qu.:0.0000   1st Qu.:56.00  
##  Median :11.00   Median : 255.5   Median :1.0000   Median :63.00  
##  Mean   :11.09   Mean   : 305.2   Mean   :0.7237   Mean   :62.45  
##  3rd Qu.:16.00   3rd Qu.: 396.5   3rd Qu.:1.0000   3rd Qu.:69.00  
##  Max.   :33.00   Max.   :1022.0   Max.   :1.0000   Max.   :82.00  
##  NA's   :1                                                        
##       sex           ph.ecog          ph.karno        pat.karno     
##  Min.   :1.000   Min.   :0.0000   Min.   : 50.00   Min.   : 30.00  
##  1st Qu.:1.000   1st Qu.:0.0000   1st Qu.: 75.00   1st Qu.: 70.00  
##  Median :1.000   Median :1.0000   Median : 80.00   Median : 80.00  
##  Mean   :1.395   Mean   :0.9515   Mean   : 81.94   Mean   : 79.96  
##  3rd Qu.:2.000   3rd Qu.:1.0000   3rd Qu.: 90.00   3rd Qu.: 90.00  
##  Max.   :2.000   Max.   :3.0000   Max.   :100.00   Max.   :100.00  
##                  NA's   :1        NA's   :1        NA's   :3       
##     meal.cal         wt.loss       
##  Min.   :  96.0   Min.   :-24.000  
##  1st Qu.: 635.0   1st Qu.:  0.000  
##  Median : 975.0   Median :  7.000  
##  Mean   : 928.8   Mean   :  9.832  
##  3rd Qu.:1150.0   3rd Qu.: 15.750  
##  Max.   :2600.0   Max.   : 68.000  
##  NA's   :47       NA's   :14
# Cox model ----
# Proportional hazards fit of survival time on sex and age; the summary
# reports log hazard ratios (coef) and hazard ratios (exp(coef)).
cox_fit <- coxph(
  formula = Surv(time, status) ~ sex + age,
  data = lung
)
summary(cox_fit)
## Call:
## coxph(formula = Surv(time, status) ~ sex + age, data = lung)
## 
##   n= 228, number of events= 165 
## 
##          coef exp(coef)  se(coef)      z Pr(>|z|)   
## sex -0.513219  0.598566  0.167458 -3.065  0.00218 **
## age  0.017045  1.017191  0.009223  1.848  0.06459 . 
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
##     exp(coef) exp(-coef) lower .95 upper .95
## sex    0.5986     1.6707    0.4311    0.8311
## age    1.0172     0.9831    0.9990    1.0357
## 
## Concordance= 0.603  (se = 0.025 )
## Likelihood ratio test= 14.12  on 2 df,   p=9e-04
## Wald test            = 13.47  on 2 df,   p=0.001
## Score (logrank) test = 13.72  on 2 df,   p=0.001
# Royston–Parmar model ----
# Flexible parametric fit with the same covariates as the Cox model; the
# baseline is a spline in log(time) with df = 4 (the nsx(log(time), df = 4)
# terms in the summary).
rp_fit <- stpm2(
  formula = Surv(time, status) ~ sex + age,
  data = lung,
  df = 4
)
summary(rp_fit)
## Maximum likelihood estimation
## 
## Call:
## stpm2(formula = Surv(time, status) ~ sex + age, data = lung, 
##     df = 4)
## 
## Coefficients:
##                           Estimate Std. Error z value     Pr(z)    
## (Intercept)             -5.5286249  0.9391697 -5.8867 3.939e-09 ***
## sex                     -0.5072641  0.1672622 -3.0327  0.002423 ** 
## age                      0.0162050  0.0091916  1.7630  0.077897 .  
## nsx(log(time), df = 4)1  4.4255382  0.6852237  6.4585 1.057e-10 ***
## nsx(log(time), df = 4)2  4.4612587  0.4837175  9.2229 < 2.2e-16 ***
## nsx(log(time), df = 4)3  7.8840443  1.3182551  5.9807 2.222e-09 ***
## nsx(log(time), df = 4)4  4.9501240  0.3775678 13.1106 < 2.2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## -2 log L: 2292.878
# Compare coefficients ----
# Put the covariate effects (log hazard ratios) from both fits side by side.
# The covariate names are taken from the Cox fit rather than hard-coded, so
# the two columns stay aligned if the model formula changes.

cox_coefs <- coef(cox_fit)

# Fail loudly instead of silently producing NA rows if the two models do not
# share the same covariates.
stopifnot(all(names(cox_coefs) %in% names(coef(rp_fit))))

# Keep only the covariate terms; this drops the RP intercept and spline terms
rp_coefs <- coef(rp_fit)[names(cox_coefs)]

comparison <- data.frame(
  Variable = names(cox_coefs),
  Cox = cox_coefs,
  RP = rp_coefs
)

print(comparison)
##     Variable         Cox          RP
## sex      sex -0.51321852 -0.50726415
## age      age  0.01704533  0.01620497
# Cumulative hazard comparison ----
# Both curves are evaluated at the SAME covariate pattern (mean age, sex = 1).
# basehaz(cox_fit, centered = FALSE) returns the Cox cumulative hazard with
# all covariates set to zero (age = 0, sex = 0), which is not comparable with
# the RP prediction below.

# Reference covariate pattern shared by both models
ref_covariates <- data.frame(
  age = mean(lung$age),
  sex = 1
)

# Cox cumulative hazard at the reference pattern. Stored as `bh` with columns
# `hazard` and `time` (the layout basehaz() returns) so later code that reads
# bh$hazard / bh$time keeps working.
cox_ref <- survfit(cox_fit, newdata = ref_covariates)
bh <- data.frame(
  hazard = as.numeric(cox_ref$cumhaz),
  time = cox_ref$time
)

# Time grid for RP model
t_grid <- seq(1, max(lung$time), length.out = 200)
newdata <- data.frame(
  age = ref_covariates$age,
  sex = ref_covariates$sex,
  time = t_grid
)

# RP cumulative hazard
rp_cumhaz <- predict(rp_fit,
                     newdata = newdata,
                     type = "cumhaz")

# Plot; the y-axis spans both curves so that neither one is clipped
plot(bh$time, bh$hazard,
     type = "s",
     col = "blue",
     lwd = 2,
     ylim = range(0, bh$hazard, rp_cumhaz, na.rm = TRUE),
     xlab = "Time",
     ylab = "Cumulative Hazard",
     main = "Cox vs Royston–Parmar")

lines(t_grid, rp_cumhaz,
      col = "red",
      lwd = 2)

# Cumulative hazards rise towards the top right, so the legend goes top left
legend("topleft",
       legend = c("Cox", "RP"),
       col = c("blue", "red"),
       lwd = 2)

# Hazard comparison ----
# Both hazards are evaluated at the covariate pattern stored in `newdata`
# (age and sex are constant across its rows). The Cox curve is computed here
# from cox_fit directly, so it does not depend on which covariate values `bh`
# was built for.
ref_row <- newdata[1, c("age", "sex"), drop = FALSE]
cox_ref_fit <- survfit(cox_fit, newdata = ref_row)
cox_time <- cox_ref_fit$time

# Cox hazard: finite difference of the cumulative hazard between successive
# observed times. The first value is NA (no earlier point to difference
# against).
# NOTE(review): this is a crude, spiky estimate (zero wherever the cumulative
# hazard does not jump); consider a kernel-smoothed hazard if a smoother Cox
# curve is wanted.
cox_haz <- c(NA, diff(as.numeric(cox_ref_fit$cumhaz)) / diff(cox_time))

# RP hazard
rp_haz <- predict(rp_fit,
                  newdata = newdata,
                  type = "haz")

# Plot hazards; the y-axis spans both curves so that neither one is clipped
plot(cox_time, cox_haz,
     type = "l",
     col = "blue",
     lwd = 2,
     ylim = range(0, cox_haz, rp_haz, na.rm = TRUE),
     xlab = "Time",
     ylab = "Hazard",
     main = "Hazard: Cox vs RP")

lines(t_grid, rp_haz,
      col = "red",
      lwd = 2)

legend("topright",
       legend = c("Cox", "RP"),
       col = c("blue", "red"),
       lwd = 2)