library(tidyverse)
library(reticulate)
library(forcats)
library(lubridate)
library(knitr)
library(vroom)
library(labelled)
library(gtsummary)
library(ggpubr)
library(RColorBrewer)
library(lemon)
library(paletteer)
library(dbplot)
library(survival)
library(survminer)
library(bshazard)
library(magick)
library(cowplot)
library(rms)

# Session-wide settings: UTF-8 as the default encoding, and knitr chunks that
# echo their code while hiding messages and warnings.
options(encoding = "UTF-8")
knitr::opts_chunk$set(message = FALSE, warning = FALSE, echo = TRUE)

# gtsummary theme element: use an empty `big.mark` when formatting numbers.
set_gtsummary_theme(list("style_number-arg:big.mark" = ""))
# Transactions sample: read with explicit column types (dates, logical flags,
# payment method as a factor), drop exact duplicate rows, re-level payment
# methods by descending frequency and sort by `msno` and transaction date.
trans <- read_csv(
  file = "D:/MyDocuments/Kaggle/Kkbox/data/trans_sample.csv",
  col_types = cols(
    transaction_date = col_date(),
    membership_expire_date = col_date(),
    is_cancel = col_logical(),
    is_auto_renew = col_logical(),
    payment_method_id = col_factor()
  )
) %>%
  distinct() %>%
  mutate(payment_method_id = fct_infreq(payment_method_id)) %>%
  arrange(msno, transaction_date)


# Members sample: read with explicit column types, then derive the
# categorical variables used in the descriptive tables.
members <- read_csv("D:/MyDocuments/Kaggle/Kkbox/data/members_sample.csv",
                    col_types = cols(registration_init_time = col_date(),
                                     city = col_factor(),
                                     gender = col_factor(),
                                     registered_via = col_factor()))
members <- members %>%
  mutate(
    # Order city and registration-method levels by descending frequency.
    city = fct_infreq(city),
    registered_via = fct_infreq(registered_via),
    # Relabel the two known genders; every other value (including missing)
    # becomes NA in factor() and is then turned into an explicit "Unknown"
    # level.
    # NOTE(review): fct_explicit_na() is deprecated in recent forcats in
    # favour of fct_na_value_to_level(); kept here so the script still runs
    # with older forcats versions.
    gender = fct_explicit_na(factor(gender, c("female", "male"),
                                    c("Female", "Male")),
                             "Unknown"),
    year_reg = factor(year(registration_init_time)),
    # Pin the levels to 1:12 so the 12 labels in `month.abb` always match.
    # Without `levels`, factor() derives the levels from the months present
    # in the data and fails with "invalid 'labels'" as soon as the sample
    # lacks at least one calendar month.
    month_reg = factor(month(registration_init_time),
                       levels = 1:12,
                       labels = month.abb)
  )

# User-logs sample: `date` is stored as a compact yyyymmdd string.
ulogs <- read_csv("D:/MyDocuments/Kaggle/Kkbox/data/ulogs_sample.csv",
                  col_types = cols(date = col_date("%Y%m%d")))

# Total number of songs played per row = sum of the five completion buckets.
# rowSums() is vectorised, so it avoids one R-level sum() call per row.
ulogs$num_tot <- rowSums(ulogs[, paste0("num_", c(25, 50, 75, 985, 100))])

ulogs <- ulogs %>%
  mutate(
    av_times_played = num_tot / num_unq,
    num_repeats = num_tot - num_unq,
    # Share of each completion bucket in the total number of plays.
    sh_25 = num_25 / num_tot,
    sh_50 = num_50 / num_tot,
    sh_75 = num_75 / num_tot,
    sh_985 = num_985 / num_tot,
    sh_100 = num_100 / num_tot,
    # Average % of a song's length played: each bucket is weighted by the
    # midpoint of its range (0-25% -> 0.125, 25-50% -> 0.375,
    # 50-75% -> 0.625, 75-98.5% -> 0.8675, 98.5-100% -> 0.9925).
    av_pct_played = (num_25 * 0.125 + num_50 * 0.375 + num_75 * 0.625 +
                       num_985 * 0.8675 + num_100 * 0.9925) / num_tot * 100
  )

# Human-readable variable labels (labelled package) for the three tables.

# Transactions
trans <- trans %>%
  set_variable_labels(
    payment_method_id = "Payment method",
    payment_plan_days = "Payment plan days",
    plan_list_price = "Plan list price",
    actual_amount_paid = "Actual amount paid",
    is_auto_renew = "Auto-renew",
    is_cancel = "Cancellation"
  )

# Members
members <- members %>%
  set_variable_labels(
    city = "City",
    bd = "Age",
    gender = "Gender",
    registered_via = "Registration method",
    year_reg = "Year of registration at the service",
    month_reg = "Month of registration at the service"
  )

# User logs: raw counts, derived totals and bucket shares
ulogs <- ulogs %>%
  set_variable_labels(
    num_25 = "Number of songs played less than 25% of the song length",
    num_50 = "Number of songs played between 25% to 50% of the song length",
    num_75 = "Number of songs played between 50% to 75% of the song length",
    num_985 = "Number of songs played between 75% to 98.5% of the song length",
    num_100 = "Number of songs played over 98.5% of the song length",
    num_tot = "Total number of songs played",
    num_unq = "Number of unique songs played",
    av_times_played = "Average number of times of listening to one song",
    num_repeats = "Number of repeats",
    sh_25 = "Share of songs played less than 25% of the song length",
    sh_50 = "Share of songs played between 25% to 50% of the song length",
    sh_75 = "Share of songs played between 50% to 75% of the song length",
    sh_985 = "Share of songs played between 75% to 98.5% of the song length",
    sh_100 = "Share of songs played over 98.5% of the song length",
    av_pct_played = "Average % of one song's length played",
    total_secs = "Total seconds played"
  )

1 Introduction


This study project aims to predict the probability of customers’ attrition (churn) with methods from survival analysis. In comparison with standard classification approaches, which result in a prediction of churn as a binary target variable or a probability of churn over some fixed period of time, survival analysis can be useful in understanding the dynamics of customer retention and attrition over time after some starting point (usually the start of the relationship - contract, subscription, etc.). It can help in answering such questions as when clients churn the most, whether there are periods of increasing, stable or decreasing churn rates, and whether all of these differ depending on customers’ characteristics, their behavior or some contract options. In addition, it is applied in estimating customer lifetime value for a company.

The dataset for the project was taken from Kaggle competition “WSDM - KKBox’s Churn Prediction Challenge” which was held in 2017. KKBox is an Asian company which offers subscription based music streaming service. The criteria of churn used in the competition is no new valid service subscription within 30 days after the current membership expires. We will use the same definition, although we will predict not just the probability of churn during the next 30 days after membership expiration as a binary outcome, but the distribution of time to churn and retention rates for different time points after the start of subscription. We will use models from both statistical analysis and machine learning, as well as both baseline features (available at the start of subscription) and time-dependent variables (information on user’s behavior during subscription).


2 Literature review: survival analysis


Survival analysis dates back to the 17th century when cross-sectional life tables were developed by John Graunt and Edmond Halley. There were further developments by Daniel Bernoulli in 1760, when he computed life expectancy using life tables. These approaches are still used in demography and epidemiology. Throughout the centuries, survival analysis was solely linked to the investigation of mortality rates; however, in the last few decades, applications of its methods have been extended beyond biomedical research to other fields such as criminology, sociology, marketing, economics, institutional research, health insurance and HR practice (Camilleri 2019).


2.1 Advantages of survival analysis in churn prediction


Our methodology for predicting customers’ churn will be based on the approach of survival analysis, for which the main variable of interest is the (expected) time to some event. In addition, for each observation we have a status variable, taking value 1 if the event happened at the end of the corresponding time period, or 0 otherwise. The latter observations are called censored: it means that we observed them for the whole time period without the event, but we admit that this event can happen for them in the future (after the period of observation), so the time to event for them will be censored on the right and will be denoted as observation time+. If no observations were censored, we could use standard regression models for continuous responses to predict time to event or models for binary classification to predict the probability of event during some prespecified period of time (for example, one year after the beginning of observation). Besides taking censoring into account, in comparison with these models survival analysis has several advantages (Harrell 2015):

  • Of course, above all, it is a priority option of choice when time to some event is of the main interest. Knowing its value could be also of use in estimating user’s lifetime value. In addition, the probability of surviving past a certain time (for example, the probability that a user will be still active after a month or a year of subscription) is often more relevant than the expected survival time, and survival analysis has a potential in predicting this probability.

  • Time to event can have an unusual distribution: it is restricted to be positive, so it has a skewed distribution and will never be normally distributed. It is not a problem for survival analysis, but can be a problem to standard models with continuous target variable.

  • One of the functions used in survival analysis, the hazard function, helps to understand the mechanism of failure (in our case the mechanism of churn) and how it changes over the survival time (period of subscription). Moreover, we can predict patterns in outcome (churn behavior) over time for individual subjects.

    For example, in dependence of subscription service peculiarities (e.g. prepaid or postpaid, free-trial periods, subscription plans, etc.) there can be different types of hazard functions (Vargas 2018):


2.2 Conceptual framework of survival analysis


In our study the event is a churn, i.e. a membership expiration such that during the next 30 days it is not reactivated. It means that if 30 days after the end of the subscription period haven’t passed yet, we are not able to tell whether this client churned, and we should consider him/her as censored. Our target variable is time to event (T) - the time between the start and the end of subscription. We will define the principles of determining the start and the end points of subscription in more detail below.

We will assume that T is a continuous random variable with probability density function f(t) and cumulative distribution function \(F(t) = P(T < t)\), giving the probability that the event has occurred by time t. In survival analysis the statistical model for time to event (T) is usually defined with the complement of the cumulative distribution function (Rodríguez 2007), called a survival function S(t). The survival function is the probability that an individual survives from the time origin to a specified time t or, equivalently, that the event has not occurred by time t and will (may) occur after it. In our analysis it will be the probability that a user remains active from the start of his/her subscription to a specified time point t (or the probability that he/she churns some time after t): \[S(t) = P(T>t) = 1-F(t)=\int_t^\infty f(x)\,dx\] The survival function is non-increasing in t and equals 1 at t = 0.

Another important term in survival analysis is hazard h(t) - the probability that an individual who is still observed at time t has an event in small interval after t: \[h(t)=\lim_{u \to 0}\frac{P(t<T<t+u | T>t)}{u} = \frac{f(t)}{S(t)}=-\frac{d}{dt}logS(t)\] Hazard will characterize the intensity of churn over time, i.e. a rate of churn per unit of time. As the width of the interval (u) goes to zero, we obtain an instantaneous rate of churn for some specific time point.

The numerator in the first part of the above expression is the conditional probability that the churn will occur in the small interval after t ([t, t+u)) given that it has not occurred before, and the denominator is the width of the interval (u). The second part of the formula says that the hazard rate of churn at time t equals the density of churns at t, divided by the probability of membership lasting to that point in time without churning. In addition, it can be proven that h(t) equals the minus slope of the log of the survival function. The latter gives us the formula (Rodríguez 2007): \[S(t) = exp(-\int_0^t h(v)dv)\]

At last, the area under the curve for hazard h(t) is a cumulative hazard function (H(t)): \[H(t) = \int_0^t h(v)dv = -logS(t)\] It presents the sum of the risks you face going from duration 0 to t (Rodríguez 2007). From the last equation we have \[S(t) = exp(-H(t))\]

Therefore, knowing any one of the functions S(t), H(t), or h(t) allows one to derive the other two functions. These three functions are different ways of describing the same distribution for the time to event.


One of the problems in survival analysis is that the form of the true population survival distribution function S(t) is almost always unknown. As a consequence, we should use some estimate for it. If there is no censoring, we can use the empirical distribution function of time to event, \(F_n(t)\), to nonparametrically estimate survival as \(S_n(t) = 1 - F_n(t) = [\text{number of }T_i > t]/n\) (Harrell 2015). With censoring taking place we can use either other nonparametric estimates (the most frequently used here are the Kaplan–Meier product-limit estimator and the Nelson-Aalen estimator) or parametrically estimate the survival function (it means that we should specify a functional form for S(t) and estimate its unknown parameters). Parametric models give more precise and smooth estimates for survival, hazard and cumulative hazard, and allow one to easily compute quantiles for time to event and to predict it, but they require accurate specification of the distribution function.

Below the example of Kaplan-Meier survival curve is given (Kazmi 2019):

There are several types of most frequently used parametric models for survival function: Weibull, exponential (it is a Weibull distribution with shape = 1), log-normal, log-logistic and gamma. For more examples of distributions and their characteristics see Herndon (1988). These distributions can be useful for modelling different patterns of hazard behavior: for example, exponential - for the hazard that is constant in time, Weibull, Gompertz, gamma – for monotonically decreasing or increasing hazard, while log-logistic, log-normal, generalized gamma and generalized F can be used for modelling bathtub-shaped hazard functions.

Here are several examples of survival, hazard and density plots for four of these distributions with different parameters (Emmert-Streib and Dehmer 2019):

For gamma distribution:

In more general case we can use any distribution defined for \(t \in [0,\infty)\) or for transformed time to event (usually log transformed) or models derived from spline smoothing for hazard functions.


To estimate the unknown parameters of the survival function S(t), maximum likelihood estimation is applied. However, it has some peculiarities in the presence of censoring, i.e. unobserved actual time to event for some subjects.

If for the i-th subject we denote the censoring time as \(C_i\) and the time to event as \(T_i\), then we can define the event indicator as: \[e_i = \begin{cases} 1, & if \text{ } T_i \leq C_i\\ 0, & if \text{ } T_i > C_i \end{cases}\] and the observed target variable as: \[Y_i = min(T_i, C_i)\]

Thus, the pair \((Y_i, e_i)\) contains all the information we need to know about the target variable, and it is usually used in realization of survival models in different libraries in Python and R.

If \(Y_i\) is uncensored, observation i contributes a factor to the likelihood equal to the density function for T evaluated at \(Y_i\), \(f(Y_i)\). If \(Y_i\) represents a censored time, it is only known that \(T_i > Y_i\). The contribution to the likelihood function is then the probability that \(T_i > C_i\), which equals the probability that \(T_i > Y_i\). This probability is \(S(Y_i)\). So the joint likelihood over n observations is (Harrell 2015): \[L = \prod_{i:Y_i\text{ uncensored}}^{n} f(Y_i) \prod_{i:Y_i\text{ censored}}^{n} S(Y_i)\] (under the assumption of non-informative censoring, i.e. censoring that is independent of the risk of the event). The model is estimated by maximizing this likelihood or, equivalently, the log likelihood, which, using \(f(t) = h(t)S(t)\) and \(\log S(t) = -H(t)\), can be written as: \[logL = \sum_{i:Y_i\text{ uncensored}}^n log (h(Y_i)) - \sum_{i=1}^n H(Y_i)\]

All observations then contribute an amount to the log likelihood equal to the negative of the cumulative hazard evaluated at the event/ censoring time. In addition, uncensored observations contribute an amount equal to the log of the hazard function evaluated at the time of event. Once L or log L is specified, the general maximum likelihood methods can be used without change in most situations. The principal difference is that censored observations contribute less information to the statistical inference than uncensored ones.


2.3 Prediction models in survival analysis


Of course, if we are to predict a time to event or survival distribution, we usually would like to include some predictors into our analysis. It could be done by generalizing survival model into survival regression model (which is typical for traditional statistical analysis) and estimate it using statistical or machine learning based methods. Statistical methods focus more on characterizing both the distributions of the event times and the statistical properties of the parameter estimation. Machine learning methods are usually less conservative in their assumptions. The taxonomy of all methods developed for survival analysis is presented below (Wang, Li, and Reddy 2019):


Overall, survival models are used to predict one or all of the following measures (Sonabend 2021):

  • The time to event: \[g: X \to T\] where \(X \subseteq \mathbb{R}^m\) is a matrix for m features, \(T\) - predicted time to event, \(g\) - a prediction functional, which is trained on the training set \((X,Y,e)\) with the observed time \(Y \subseteq \mathbb{R}_{\geq 0}\) (some models assume that \(Y \subseteq \mathbb{R}_{> 0}\), so for zero observed time some small value should be added) and an indicator for event (1) and censoring (0) status \(e = \{0,1\}\). It is the deterministic kind of task.

  • The relative risk of an individual experiencing an event: \[X \to \mathcal{R}\] where \(\mathcal{R} \subseteq \mathbb{R}\) is a relative risk. It is a ranking kind of task. By themselves relative risks do not have a meaningful interpretation, but they can be used to rank and/or cluster individuals by relative risk. Time-to-event predictions can be seen as a special case of the ranking problem, as an individual with a predicted longer survival time will have a lower overall risk.

  • An individual’s survival distribution: \[X \to \mathcal{S}\] where \(\mathcal{S}\) comes from the convex set of distributions on \(T\): \(\mathcal{S} \subseteq Distr(T)\). It is a probabilistic kind of task. Actually, using different models we can obtain different representations of the survival distribution (survival, or hazard, or cumulative hazard function), but as we know, they are all related to each other. The above prediction tasks can all be solved using the survival distribution prediction.


2.3.1 Survival regression models


All survival regression models can be divided into two main groups, depending on their assumptions of the nature of covariates’ effects, namely: accelerated failure time and proportional hazard models.


1) Accelerated failure time (AFT) models:

Predictors act multiplicatively on the failure time or additively on the log failure time: \[logT = X\beta + \sigma\epsilon \Rightarrow T = exp(X\beta)T_0\] where X is an \(n \times m\) matrix of m characteristics for n individuals (may include an intercept), \(\beta\) - a vector of m coefficients (weights), \(\epsilon\) - an error term with some prespecified distribution \(\psi\) (see below) or estimated non-parametrically (rarely used), \(\sigma\) - a scale parameter, \(T_0 = exp(\sigma\epsilon)\) (Rodríguez 2007; Harrell 2015).

The value of \(exp(X\beta) = \gamma\) is called an accelerator factor, because the effect of a predictor in AFT models is to alter the rate at which a subject proceeds along the time axis (i.e., to accelerate the time to failure) (Harrell 2015). For example, the failure time for the i-th individual with the j-th characteristic by one unit higher than for the k-th subject (all other things being equal) will be \(exp(\beta_j)\) times the failure time of the k-th subject.

Using AFT models, we can easily predict different representations of the individual survival distribution – of course, if such nature of variables effect on time is typical for the field of study. In addition, AFT models are usually parametric and require accurate assumptions about errors distribution (\(\psi\)) (for details see Abdul-Fatawu (2020) and Harrell (2015)). We do not plan using them in this study.


2) Proportional hazards (PH) models:

While AFT models assume that predictors have a multiplicative effect on the time to event and an additive effect on the log time to event, proportional hazards models assume the same for the hazard and log hazard, respectively (Rodríguez 2007): \[h(t|X) = h_0(t)exp(X\beta) \Rightarrow log(h(t|X)) = log(h_0(t)) + X\beta\] where \(h_0(t)\) is the baseline hazard function (when all predictors equal zero), \(X\) - matrix of features (not including a constant!), \(\beta\) - regression coefficients (weights). Here \(exp(x_i'\beta)\) is the relative risk associated with the set of characteristics \(x_i\) for the i-th individual in comparison with the risk of an individual with all characteristics equal to zero. So in contrast to AFT models, the hazard is positively related to the risk component, \(exp(X\beta)\).

The alternative way to specify a PH regression is to add a constant to feature matrix, such that: \[h(t|X) = exp([C|X]\beta)\] then baseline hazard will be estimated as \(h_0(t)=exp(\beta_0)\), where \(\beta_0\) is the weight for the intercept.

It is important to emphasize that the relative risk does not depend on time, i.e. it is constant in time for the same pair of values of any feature, so hazards are proportional independent of time.

The same holds for cumulative hazards (Rodríguez 2007): \[H(t|X) = H_0(t)exp(X\beta)\]

The corresponding survival function will be (Rodríguez 2007): \[S(t|X) = S_0(t)^{exp(X\beta)}\] where \(S_0(t)\) is the baseline survival function. Thus, the effect of the set of characteristics \(x_i\) for the i-th individual on the survivor function is to raise it to a power given by the relative risk \(exp(x_i'\beta)\).

We can use some specific distribution for the baseline survival or hazard function – and then we get a parametric PH model, or do not make such assumptions about the baseline functions – and then get a semi-parametric PH model, often called a Cox PH model after its author.

Concerning parametric PH models, it is worth noticing that there are two distributions which give regression models that belong to both the AFT and PH families. They are the Weibull distribution and its special case, the exponential distribution. So the results of Weibull and exponential regression models can be interpreted in both ways.

In other cases if we plot a log hazard on the log time scale the effect of the PH assumption is to model the change in hazard as a vertical shift, while the effect of the AFT assumption is to model the change in hazard as a horizontal shift (Breheny n.d.a). This approach could be of use in choosing between AFT and PH models, as well as in assessing the fulfillment of their assumptions.

Below a summary description of both AFT and PH parametric models is presented (Mills 2010).

As usual, parametric PH models, require correct specification of distributions for baseline survival, hazard or cumulative hazard function. The model parameters are estimated with maximum likelihood estimation, with the likelihood dependent on the chosen distribution with a possibility of treating the distribution choice as a model hyper-parameter. In the output we get a continuous survival distribution prediction.

Semi-parametric (Cox) PH models are more widely used than parametric PH models, because they do not have parametric assumptions concerning baseline functions - just the effect of covariates on hazard. Cox argued that when the PH assumption holds, information about baseline hazard function is not very useful in estimating the parameters of primary interest (\(\beta\)). By special conditioning in formulating the log likelihood function, Cox showed how to derive a valid estimate of \(\beta\) that does not require estimation of baseline hazard as it is dropped out of the new likelihood function. Cox’s derivation focuses on using the information in the data that relates to the relative hazard function \(\exp(X\beta)\) (Harrell 2015).

As a result, estimation of the Cox PH model is based on the maximization of the partial likelihood: \[L(\beta) = \prod_{i: Y_i\text{ uncensored}}\frac{exp(X_i\beta)}{\sum_{Y_j \geq Y_i}exp(X_j\beta)}\] or, equivalently, minimization of the negative log partial likelihood \(-logL\). The Newton-Raphson technique for iterative estimation is often used here.

After estimating the Cox PH model we can get a linear predictor for each individual as \(x_i'\hat\beta\) and use it to rank individuals by relative risk of event.

The Cox model and parametric survival models differ markedly in how one estimates \(S(t|X)\). Since the Cox model does not depend on a choice of the underlying survival function \(S(t)\), fitting a Cox model does not result directly in an estimate of \(S(t|X)\). Several authors have derived secondary estimates for it (for more details see Harrell (2015)). Usually they use some non-parametric estimator for the baseline hazard function. As a result, they give a discrete survival distribution prediction. In contrast, there are models that use splines to fit the baseline function (\(\log(H_0(t))\), the log baseline cumulative hazard, is estimated by natural cubic splines with coefficients fit by maximum likelihood estimation (Royston and Parmar 2002)). Parametric models give a direct functional form for the baseline survival function, from which the individual survival distribution can easily be obtained using the formula above.


3) Extensions of the basic AFT and PH models:

There are several ways to extend the described basic AFT and PH models to receive more accurate predictions. Although most of them are easier incorporated into PH models.

The common problem for all PH models is a violation of the PH assumption (several methods can be used to check this). If AFT models are not the appropriate alternatives, then stratified Cox regression can be estimated (it assumes different baseline hazard functions for different strata) or Cox regression with time-dependent effects (in this case the assumption of PH is dropped out for some or all covariates: \(h(t|X) = h_0(t)exp(X\beta(t))\)). Different parametric assumptions can be put on \(\beta(t)\) and according to them it can be modeled as a piecewise constant, linear, piecewise linear, spline, polynomial function of time and other approaches (for more details see Breheny (n.d.b)).

Another extension aimed at receiving more accurate predictions is regularization of the regression coefficients. Here we can apply L1, L2, elastic net or other regularization methods (for details see Simon et al. 2011; Wang, Li, and Reddy 2019). If there are covariates that should remain in the model even after regularization, then CoxBoost can be applied (Wang, Li, and Reddy 2019).

At last, we can add time-dependent covariates to Cox regression, i.e. such features which are estimated not only at the time origin, but also after it, during the process of observation. In our case it will be all characteristics of users’ behavior obtained after beginning of the subscription period.

PH models with time-dependent covariates will have the following form: \[h(t, X(t)) = h_0(t)exp(X(t)\beta)\]

Calculation of survival functions in this model is a little bit more complicated, because we need to specify a path or trajectory for each variable (Rodríguez 2007). Usually it is done by splitting the whole period of time into intervals between time points at which the corresponding covariate was measured (or changed its value) - thus, instead of one observation we obtain several pseudo-observations for one individual of type \((time\text{_start}, time\text{_stop}, event)\), which means that at the end of the interval \((time\text{_start}, time\text{_stop}]\) an individual has an \(event\) status (0 or 1), and all time-dependent covariates are measured as of at the end of this interval. Thus, in fact, all these pseudo-observations (except, maybe, the first one) become left-truncated right-censored.

This data form is called an Andersen-Gill reformulation or a counting process form and does not change the form of partial likelihood function - only the interpretation of its elements and results of its maximization (T. Therneau, Crowson, and Atkinson 2017). Baseline functions in this framework are calculated for subject with all covariates equal zero for all t and no longer have a clear interpretation. \(exp(x_{ij}(t)\beta_j)\) for some i-th individual and j-th feature is interpreted as the relative risk of event for the individual with the value of \(x_{ij}\) obtained by time t in comparison with the individual who had a zero value of this feature by the same time. With such interpretation we can say that it is more useful for categorical predictors, while for continuous ones it may require some preliminary centralization of values for all subjects and points.

Non-parametric estimates, including Kaplan-Meier curves are not applicable for time-dependent covariates, as they preserve splitting observations into groups as they are made at the starting point of time. Simon and Makuch proposed a technique that evaluates the covariate status of the individuals remaining at risk (i.e. without event or censoring by that point) at each event time, and their estimator can be used to correctly plot Kaplan-Meier curves for time-dependent variables (although, again, only for categorical ones) (Schultz, Peterson, and Breslau 2002).

Predicting survival function with several covariates, one or more of which are time-dependent, is quite more complicated – both in estimation, prediction and interpretation. Laine and Reyes (2014) describes the process of estimating survival after obtaining hazards from Cox PH regression: for particular covariate trajectory \(x^*(t)\)

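\[\hat{H}(t|x^*(t)) =\sum_{i=1}^n \int_0^t\frac{exp(\hat{\beta}x^*(u))dN_i(u)}{\sum_j Y_j(u)exp(\hat{\beta}X_j(u))}\] where \(N_i(t)=\mathbb{I}(T_i \leq t,\ e_i = 1)\) is the counting process that jumps to 1 when the event of the i-th individual is observed, and \(Y_j(u)\) is the at-risk indicator, equal to 1 if the j-th individual is still under observation just before time u and 0 otherwise. Then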

\[\hat{S}(t|x^*(t))=exp(-\hat{H}(t|x^*(t)))\] and now it is not a simple function, as now it requires unique integration to estimate cumulative hazard for every value of \(x^*(t)\).

In addition, Laine and Reyes (2014) note that in the presence of time-dependent covariates, it may not make sense to calculate survival probabilities or predictions. This requires knowledge of \(X(t)\), which may be unknown until time t, at which point its observation frequently implies survival. For example, individuals who have surgery at 1 year are alive at 1 year, by definition. Hence, survival estimation is rarely implemented in this case in medical studies.

Although, in our opinion, it could be useful in the process of observation (after the start of therapy, subscription or other starting points depending on the nature of the study), as we obtain additional information, and using it we can update survival function for individuals or groups. For example, Yao et al. (2020) illustrate how updated and not updated survival curves can be different:


2.3.2 Machine learning models


Machine learning models for survival data, in comparison with the same models for “ordinary” continuous or class dependent variables, can deal with censored data. Here we have survival trees and their ensembles, including random survival forest (RSF), gradient boosting machines, support vector machines, and neural networks. Usually they do not have such restrictive assumptions as statistical models, although there are some models that are based on the PH assumption. A comprehensive survey of different machine learning models for survival data is presented in Sonabend (2021). The following table was taken from there:

Class: RSF – Random Survival Forest and Decision Trees; GBM – Gradient Boosting Machine; SVM – Support Vector Machine; ANN – Artificial Neural Network. Task: Det. - Deterministic, Prob. - Probabilistic, Rank - Ranking.


The thing is, that for ML models it is still not very common to be adapted for use of the time-dependent covariates. As one of the purposes of our study was to predict customers’ churn with time-varying features from all the variety of ML models we choose random survival forest (RSF), for which, unlike other machine learning methods, the algorithm can be more easily adapted both for working with survival data and for using time-dependent covariates.

As Sonabend (2021) notes, the difference between ordinary random forest and random survival forest lies in different choices of splitting rules and terminal node predictions – all fundamentals remains unchanged. He distinguishes two groups of RSF models by splitting rules being used: those relied on hypothesis tests (primarily the nonparametric log-rank test, which in statistical analysis is used for comparing survival distributions; higher values of the log-rank statistic mean greater dissimilarity between distributions) and those utilise likelihood-based measures (they require likelihood estimation, which means assumption about certain model form, for example Cox PH model).

As far as terminal node predictions are concerned, all RSF models again can be separated into two groups: one of them result in terminal node ranking predictions (these models assume using likelihood-based splitting rule and a PH model form), the other predict survival distribution by estimating the survival function, using the Kaplan-Meier or Nelson-Aalen estimators, on the sample in the terminal node.

There are several approaches to adjust RSF to time-dependent covariates. Overall, they suggest splitting pseudo-observations for one subject across many nodes as a function of time, so in the end, a subject could end up in many different terminal nodes, but at any given time, each subject can be classified into one and only one terminal node (Bou-Hamad, Larocque, and Ben-Ameur 2011).


2.4 Model evaluation


The specific feature of all survival models is that the predicted outcome for the i-th individual – the expected time to event (\(\hat{T_i}\)), the relative risk (\(\hat{\eta_i}\)) or the survival distribution (\(\hat{S_i}\)) – is a different object than the outcome used for model training – the pair \((Y_i,e_i)\), as defined above. It affects the approaches to model evaluation.

Depending on the task (what is predicted by the model) different loss or score functions can be used (Sonabend 2021):

1) For time-to-event predictions:

For this task we can use the same metrics as for regressions with continuous target variable (MSE, RMSE, MAE), except the fact that they are evaluated only for uncensored individuals, i.e. for those for whom the time to event is observable. Of course, they are not very optimal when there is many censored observations.


2) For continuous relative risk (ranks) predictions and time-to-event predictions:

For this kind of tasks usually measures of discrimination are used. It means that the predicted values are considered as a measure of relative risk of event (although not interpretable), according to which observations are ranked, and then these ranks are used to calculate the measure of discrimination. It makes this measure independent of outliers. A model has perfect discrimination if it correctly predicts that patient with higher risk of event will have earlier event.

Here are two types of discrimination measures (Sonabend 2021):

  • Concordance index (or C-index, or simply C) measures the proportion of cases in which the model correctly separates a pair of observations into low and high risk. There are several form of concordance indices, but the most widely used is Harrell’s C:

    For a pair of observations i and j (\(i \neq j\)), arranged in ascending order by the values of Y (\(Y_i < Y_j\)) with predicted relative risk values \(\eta_i\) and \(\eta_j\), and censoring statuses \(e_i\) and \(e_j\):

    • If both \(T_i\) and \(T_j\) are not censored (\(Y_i = T_i\), \(Y_j = T_j\)), the pair (i, j) is a concordant pair if \(\eta_i > \eta_j\), and it is a discordant pair if \(\eta_i < \eta_j\).
    • If both \(T_i\) and \(T_j\) are censored (\(Y_i = C_i\), \(Y_j = C_j\)), this pair is not considered in the computation.
    • If one of \(T_i\) and \(T_j\) is censored, we only observe one event. If we observe the event for individual i, so \(Y_i = T_i\) and \(Y_j = C_j\), then the pair (i, j) is comparable and is classified as concordant or discordant by the same rule as above; otherwise we don’t consider this pair in the computation. \[c = \frac{\text{# concordant pairs}}{\text{# concordant pairs + # discordant pairs}} = \frac{\sum_{i \neq j}\mathbb{I}(\eta_i < \eta_j) \mathbb{I}(T_i > T_j)e_j}{\sum_{i \neq j} \mathbb{I}(T_i > T_j)e_j}\] \(c \in [0,1]\) with 1 indicating perfect separation, 0.5 indicating that the risk score predictions are no better than a coin flip in determining which individual will have longer time to event, and 0 indicating that the separation is perfectly reversed.

    C-index can be reformulated for time-to-event prediction (just replace \(\eta_i\) and \(\eta_j\) with \(\hat{T_i}\) and \(\hat{T_j}\) and change the signs between \(\eta\)s to the opposite ones).

    The other variants of the C-index differ from Harrell’s C in using some weights in the numerator and denominator, calculated based on the non-parametric estimates for survival distribution or survival function for censoring distribution (e.g. Uno’s C-index). They are less affected by censoring than Harrell’s C, but also less interpretable.

  • Time-dependent AUC and C-index:

    AUC measures for survival analysis have been developed in order to provide a time-dependent measure of discriminatory ability (for more detail see Wang, Li, and Reddy (2019)).


3) For survival distribution prediction:

For this type of tasks several types of measures can be used:

  • Calibration measures, including graphical methods of comparison between predicted and expected distribution (the latter is often approximated by the Kaplan-Meier curve) and modification of \(\chi^2\) test.

  • Scoring rules (loss functions) evaluate probabilistic predictions and (attempt to) measure the overall predictive ability of a model, i.e. both calibration and discrimination.

    Losses in the survival setting compare predicted survival distributions to the observed pairs (\(Y_i, e_i\)). Usually they also include an estimator of the unknown censoring distribution.

    The most widely used loss here is Brier score (BS): \[BS(t) = \frac{1}{N}\sum_{i=1}^N w_i(t)(\hat{S_i}(t) - S_i(t))^2\] where \(S_i(t)\) is the true survival probability for i-th individual at time t (it equals 0 if \(T_i \leq t\) and 1 otherwise), \(\hat{S_i}(t)\) - predicted survival probability, \(w_i(t)\) - the weight of i-th observation, estimated by incorporating the Kaplan-Meier estimator of the censoring distribution G obtained on the given dataset \((X,Y,1−e)\) (Wang, Li, and Reddy 2019): \(w_i(t) = e_i/G(Y_i)\) if \(Y_i \leq t\) or \(1/G(t)\) otherwise. The weights for the instances that are censored before t will be 0. However, they contribute indirectly to the calculation of the Brier score since they are used for calculating G.

    Integrated Brier score (IBS) is an overall measure for the prediction: \[IBS = \frac{1}{Y_{max}} \int_0^{Y_{max}} BS(t)dt\]


4) For models with time-dependent covariates:

It is a more complicated procedure to estimate prediction performance of models with time-dependent covariates, but here we still can calculate C-indices, as well as time-dependent Brier score and integrated Brier score. You can find more information about it in Terry M. Therneau and Watson (2017).


3 Exploratory data analysis


The whole dataset consists of several files (total size is 32 GB):

  • transactions (transactions.csv for period up to March, 2017 and transactions_v2.csv for period up to April, 2017) - information about buying a subscription, its renewal (auto or manual) and cancellation (manual, by client); we will define our dependent variables (churn and time to churn) using data from these files;

  • members’ characteristics (members_v3.csv) - information about city, age, gender, registration method and date of initial registration in the service,

  • user logs (user_logs.csv for period up to March, 2017 and user_logs_v2.csv for period up to April, 2017) - daily user logs describing listening behaviors of users.

As organizers answered to one of participants’ questions at Kaggle, sometimes they erase members information, but mostly do not erase data from transactions and user logs. Taking it into consideration we’ve decided to choose data only for those users, whose ID occurs in all three files - we’ve found almost 2 million (1 937 764) such IDs.

Further, for more convenience and for initial analysis we’ve decided to sample 1% out of these IDs and exclude those of them for whom all dates in user logs are missing. After that 19200 users remained in our dataset.


3.1 Transactions


Data for 198566 transactions from 2015-01-01 till 2017-03-31. Transactions include subscription, its renewal, changes in subscription plans and cancellations, although there is not any indicator for distinguishing these types of transactions from each other.

# Per-user timing features for every transaction. Rows are sorted by date
# within each user so that lead() picks up that user's next transaction.
trans <- trans %>%
  arrange(msno, transaction_date) %>%
  group_by(msno) %>%
  mutate(
    # days from the transaction date to the expiration date it sets
    time_bw_exp_trans = as.numeric(membership_expire_date - transaction_date,
                                   units = "days"),
    # days from this expiration date to the user's next transaction; when
    # there is no next transaction, 2017-03-31 is used instead
    time_bw_exp_nexttrans = as.numeric(
      coalesce(lead(transaction_date), as.Date("2017-03-31")) -
        membership_expire_date,
      units = "days"
    ),
    # days to the user's next transaction (NA for the last one)
    time_bw_trans = as.numeric(lead(transaction_date) - transaction_date,
                               units = "days"),
    # running transaction number within the user
    trans_id = row_number()
  ) %>%
  ungroup()

# Working copy of the transactions with a corrected expiration date:
# an expiration date earlier than its transaction date is replaced by the
# transaction date, and the expiration-based intervals are recomputed.
trans_df <- trans %>%
  arrange(msno, transaction_date) %>%
  mutate(
    membership_expire_date_cor = pmax(membership_expire_date, transaction_date)
  ) %>%
  group_by(msno) %>%
  mutate(
    time_bw_exp_trans = as.numeric(
      membership_expire_date_cor - transaction_date,
      units = "days"
    ),
    # the last transaction of a user is measured against 2017-03-31
    time_bw_exp_nexttrans = as.numeric(
      coalesce(lead(transaction_date), as.Date("2017-03-31")) -
        membership_expire_date_cor,
      units = "days"
    )
  ) %>%
  ungroup()


Transaction file contains the following information:


  • transaction_date - the date of transaction.
# One row per user: number of transactions, number of cancellations and
# number of auto-renew transactions.
trans_stat <- ungroup(
  summarise(
    group_by(trans, msno),
    n = n(),
    sum_cancel = sum(is_cancel),
    sum_auto = sum(is_auto_renew)
  )
)

# Frequency table of the gap (in days) between consecutive transactions,
# most frequent gap first; rows without a next transaction are dropped.
trans_freq_max <- trans %>%
  filter(!is.na(time_bw_trans)) %>%
  count(time_bw_trans, sort = TRUE)

# Histogram of the number of transaction records per user (values below the
# 99th percentile), with summary statistics of the full distribution printed
# inside the panel.
p1 <- trans_stat %>%
  filter(n < quantile(trans_stat$n, 0.99)) %>%
  ggplot() +
  aes(x = n) +
  geom_histogram(binwidth = 1, fill = "#4682B4") +
  # annotate() draws each label exactly once; geom_text() with constant
  # aesthetics inherits the plot data and redraws the label for every row
  annotate("text", x = 15, y = c(4000, 3700, 3400),
           label = c(
             sprintf("Mean (SD): %.1f (%.1f)",
                     mean(trans_stat$n), sd(trans_stat$n)),
             sprintf("Median (Q1, Q3): %.0f (%.0f, %.0f)",
                     median(trans_stat$n), quantile(trans_stat$n, 0.25),
                     quantile(trans_stat$n, 0.75)),
             sprintf("Min-Max: %.0f-%.0f",
                     min(trans_stat$n), max(trans_stat$n))
           ),
           hjust = 0, color = "grey20", size = 3.5) +
  labs(x = "", y = "", title = "Number of records per ID in transactions\n(for values < 99th percentile)") +
  scale_x_continuous(breaks = seq(0, 30, 5), expand = c(0,0),
                     limits = c(0, 30)) +
  scale_y_continuous(expand = c(0,0), 
                     breaks = seq(0, 8000, 1000), labels = c(0, paste0(seq(1,8,1), "K"))) +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 12))

# Histogram of the time (days) between consecutive transactions of a user
# (values below the 99th percentile), with summary statistics of the full
# distribution printed inside the panel.
p2 <- trans %>%
  filter(time_bw_trans < quantile(trans$time_bw_trans, 0.99, na.rm = TRUE)) %>%
  ggplot() +
  aes(x = time_bw_trans) +
  geom_histogram(binwidth = 5, fill = "#4682B4") +
  # annotate() draws each label exactly once; geom_text() with constant
  # aesthetics inherits the plot data and redraws the label for every row
  annotate("text", x = 90, y = c(140000, 130000, 120000),
           label = c(
             sprintf("Mean (SD): %.1f (%.1f)",
                     mean(trans$time_bw_trans, na.rm = TRUE),
                     sd(trans$time_bw_trans, na.rm = TRUE)),
             sprintf("Median (Q1, Q3): %.0f (%.0f, %.0f)",
                     median(trans$time_bw_trans, na.rm = TRUE),
                     quantile(trans$time_bw_trans, 0.25, na.rm = TRUE),
                     quantile(trans$time_bw_trans, 0.75, na.rm = TRUE)),
             sprintf("Min-Max: %.0f-%.0f",
                     min(trans$time_bw_trans, na.rm = TRUE),
                     max(trans$time_bw_trans, na.rm = TRUE))
           ),
           hjust = 0, color = "grey20", size = 3.5) +
  labs(x = "", y = "", title = "Time between transactions\n(for values < 99th percentile)") +
  scale_x_continuous(breaks = seq(0, 210, 30), expand = c(0,0)) +
  scale_y_continuous(expand = c(0,0), limits = c(0, 150000),
                     breaks = seq(0, 150000, 50000), 
                     labels = c(0, paste0(seq(50,150,50), "K"))) +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 12))

# Both histograms side by side.
ggarrange(p1, p2, nrow = 1, ncol = 2)


– About 22.7% users have only one transaction in the whole dataset.

– Usually there are about 30 days between two consecutive transactions of one user. Particularly, among all pairs of consecutive transactions about 25.1% are distanced by 30 days, and 45.9% - by 31, all other periods occur in less than 7% of cases each.


  • membership_expire_date - subscription expiration date after this transaction. There were 1415 records with membership expiration date less than transaction date. The minimum expiration date for them was 1970-01-01, for 96 transactions it went before 2015-01-01 (the minimum transaction date in the whole dataset), so there are obviously some odd values here, and we should do something with them prior to the analysis.

– Out of all negative values for the time between membership expire date and transaction date listed for the same transaction 92.9% were between -80 and 0 days; moreover, 88.3% equals -1 or -2 days, and 94.8% of all transactions with negative time are with is_cancel=1.

– There were no specific explanations for these cases from organizers, but too negative values look very odd, especially taking into account that most of negative time transactions (71.2%) were not the last for the corresponding user. So we have decided that it will be reasonable to replace membership expiration date in all such transactions into the date of the corresponding transaction: it will hardly bias some further results, as the number of such transactions is quite small, and most of them contains almost equal transaction and expiration dates.


  • is cancel - whether or not the user cancelled the membership in this transaction. Among all records 3.1% concerned cancellation. According to organizers, subscription cancellation does not imply the user has churned, as he/she may cancel subscription due to change of service plans or other reasons. Among 19200 users 27% cancelled their subscription for at least once.


  • is_auto_renew - for transactions not concerning cancellation whether it was auto or manual renewal of subscription. Among all records 83.3% are auto-renewals. All cancellation transactions are marked as auto-renewals, among non-cancellation transactions 82.7% are auto-renewals. Organizers did not give any clarification about what auto-renewal means, but as we will see below, most likely, auto-renewal means automatic renewal of the membership expiration date: for example, when a user cancels his/her subscription or changes a plan from, say, 1-month to 2-months, then membership expiration date changes automatically.

    Among 19200 users 43.2% had only manual renewals, 48.1% had only auto-renewals, the rest (8.7%) had both of them.


  • payment_method_id - payment method. Organizers at Kaggle did not provide any information about these methods. There were 35 unique numeric codes for them.
# Share of transactions per payment method, with a percentage label for each
# ("<1%" when the share rounds to zero).
plt_df <- trans_df %>%
  count(payment_method_id) %>%
  mutate(
    pct = n / sum(n) * 100,
    lbl = sprintf("%.0f%%", pct),
    lbl = ifelse(lbl == "0%", "<1%", lbl)
  )

# Bar chart of transactions per payment method; bars above 4% of all
# transactions carry their percentage label inside the bar.
ggplot(trans_df, aes(x = payment_method_id)) +
  geom_bar(fill = "#2D559E") +
  geom_text(
    mapping = aes(x = payment_method_id, y = n - 250, label = lbl),
    data = filter(plt_df, pct > 4),
    angle = 90, hjust = 1, vjust = 0.5,
    color = "white", fontface = "bold"
  ) +
  scale_y_continuous(
    expand = c(0, 0),
    limits = c(0, 100000),
    breaks = seq(0, 100000, 20000),
    labels = c(0, paste0(seq(20, 100, 20), "K"))
  ) +
  labs(title = "Payment method", y = "", x = "") +
  ggthemes::theme_tufte() +
  theme(
    plot.title = element_text(face = "bold"),
    axis.text.x = element_text(size = 10),
    axis.ticks.x = element_blank()
  )


  • payment_plan_days - length of membership plan in days. Not very informative variable, because in most cases it equals 30 days, it is often positive for cancellations and when summed with transaction date does not give membership expiration date. We do not plan to use it.
# Share of transactions per payment plan length, most frequent value first.
plt_df <- trans_df %>%
  mutate(payment_plan_days = fct_infreq(factor(payment_plan_days))) %>%
  count(payment_plan_days) %>%
  mutate(
    pct = n / sum(n) * 100,
    lbl = sprintf("%.0f%%", pct)
  ) %>%
  arrange(desc(n))

# Bar chart of transactions per payment plan length, bars ordered by
# frequency. Labels are placed inside the bar for shares above 5% and above
# the bar for shares between 3% and 80%.
ggplot(trans_df, aes(x = fct_infreq(factor(payment_plan_days)))) +
  geom_bar(fill = "#2D559E") +
  geom_text(
    mapping = aes(x = payment_plan_days, y = n - 250, label = lbl),
    data = filter(plt_df, pct > 5),
    angle = 90, hjust = 1, vjust = 0.5,
    color = "white", fontface = "bold"
  ) +
  geom_text(
    mapping = aes(x = payment_plan_days, y = n, label = lbl),
    data = filter(plt_df, pct > 3 & pct < 80),
    angle = 90, hjust = 0, vjust = 0.5,
    color = "grey20", fontface = "bold"
  ) +
  scale_y_continuous(
    expand = c(0, 0),
    limits = c(0, 175000),
    breaks = seq(0, 175000, 25000),
    labels = c(0, paste0(seq(25, 175, 25), "K"))
  ) +
  labs(title = "Payment plan days", y = "", x = "Unique values of payment plan days") +
  ggthemes::theme_tufte() +
  theme(
    plot.title = element_text(face = "bold"),
    axis.text.x = element_text(size = 10),
    axis.ticks.x = element_blank()
  )

– For cancellation transactions there are only 3 available values of payment plan days: 0, 30 and 31.


  • plan_list_price and actual_amount_paid - both have 39 unique values. For 6.3% of all transactions these values differ. This share is a bit higher for cancellations: 12.9% versus 6.1%.
# Plan list price against the amount actually paid, one panel per value of
# is_cancel (each panel keeps its own axes).
trans_df %>%
  mutate(is_cancel = factor(is_cancel,
                            labels = c("is_cancel=0", "is_cancel=1"))) %>%
  ggplot(aes(x = plan_list_price, y = actual_amount_paid)) +
  geom_point(aes(color = is_cancel), shape = "circle", size = 2, alpha = 0.5) +
  facet_rep_wrap(~ is_cancel, scales = "free") +
  scale_color_brewer(palette = "Set1", direction = -1) +
  labs(x = "Plan list price", y = "Actual amount paid") +
  theme_minimal() +
  theme(legend.position = "none")


3.1.1 Churn definition


We will use a definition of churn given by organizers. They define it as a membership expiration such that during the next 30 days it is not reactivated. So the key variable for revealing churns is a period between a membership expiration date given for this particular transaction and the date of the next transaction.

Our approach to churn definition will be based on the following assumptions:

  • As all data for transactions were extracted on March 31, 2017, for all users who had the last transaction before that day we will suppose that there were no more transactions till March 31, 2017.

  • For each user the date of his/her first available transaction in the dataset will be considered as the start of his/her subscription (membership), because we have no earlier data for transactions for those users who registered at KKBox before 2015. In other words, in the analysis of churn we will consider data as if this subscription service was opened at 2015-01-01. Although, in fact, it would be correct to say that for users who registered at the service before 2015 (54.2% of all users), the first time period of their membership should be considered as left censored, because we have no information about it actual date of beginning. But survival models with interval (both left and right) censoring are quite complicated, especially for machine learning (i.e. prediction), that is why for this particular task we decided to follow this simplification.


Starting from the first transaction for each user we can calculate time between membership expiration date in every transaction and the next transaction date (or if there is no further transactions for this user - till March 31, 2017) - time_bw_exp_nexttrans.

– There are negative values. Most likely, these are cases when user changes his/her subscription plan (including cancellation, which is always marked as auto-renewals) before the expiration date given in the previous subscription. Usually it happens in auto-renewed transactions and 1-6 days prior to expiration of the previous plan.

– For most consecutive transactions the expiration date of the first one coincides with the transaction date of the second one, or the latter falls 1 day earlier or later than the former.


According to the time between expiration date and the date of the next transaction (or 2017-03-31) we will place a churn indicator for each transaction (i.e. whether a user churned after the end of membership expiration listed in this transaction or not) as follows:

  • If this time exceeds 30 days, then we will consider this user as churned at this membership expiration date (churn = 1).

  • If this time is 30 days or less, then we can say that during the time from this till the next transaction (or till 2017-03-31 for the last user’s transaction) this user did not churn (churn = 0).


Scheme for logic of churn labelling:

# Churn flag per transaction: 1 when more than 30 days pass between the
# (corrected) expiration date and the next transaction (or 2017-03-31).
trans_df <- mutate(trans_df, churn = as.numeric(time_bw_exp_nexttrans > 30))

# Number the subscription periods within one user's ordered churn flags.
#
# churn_seq: vector of 0/1 churn indicators, one per transaction, in
#   chronological order.
# Returns a numeric vector of the same length: 1 for the first period, and
#   increased by one at every transaction that follows a churn (churn == 1).
#
# The previous flag is built with base R, so the result does not depend on
# dplyr::lag() masking stats::lag(), and the shift is computed only once.
subscr_periods <- function(churn_seq) {
  # flag of the preceding transaction; NA for the first one
  prev_churn <- c(NA, churn_seq)[seq_along(churn_seq)]
  # a new period starts at the first transaction and right after each churn
  starts_new_period <- is.na(prev_churn) | prev_churn == 1
  cumsum(as.numeric(starts_new_period))
}

# For every user, number the subscription periods along the ordered
# transactions; the position within the user becomes trans_id.
trans_periods <- trans_df %>%
  group_by(msno) %>%
  group_modify(function(user_trans, user_key) {
    tibble::enframe(
      subscr_periods(user_trans$churn),
      name = "trans_id",
      value = "msno_subscr_id"
    )
  })

# Attach the within-user period number to every transaction.
trans_df <- trans_df %>%
  left_join(trans_periods, by = c("msno", "trans_id"))

# Global period id: one integer per (user, within-user period) pair.
# The data are grouped first and the indices taken from the grouped frame,
# because passing grouping columns directly to group_indices() is deprecated.
trans_df <- trans_df %>%
  mutate(subscr_id = group_indices(group_by(trans_df, msno, msno_subscr_id)))

# Most frequent non-missing value of a vector (statistical mode).
#
# x: atomic vector, may contain NA.
# Returns NA when x is empty or entirely NA; otherwise the most frequent
#   value, taking the smallest one when several values are tied.
#
# Values are counted through their position among the sorted unique values,
# so zero, negative and non-integer values are counted correctly (a bare
# tabulate(x) only counts positive integers).
# NOTE(review): this masks base::mode() for the rest of the script.
mode <- function(x){
  if (all(is.na(x))) {
    return(NA)
  }
  candidates <- sort(unique(x))  # sort() also drops NA
  counts <- tabulate(match(x, candidates), nbins = length(candidates))
  candidates[which.max(counts)]
}

# Collapse the transactions of each subscription period (subscr_id) into a
# single row with the period's churn status, dates, counts and payment
# summaries, then derive the period length in days.
trans_periods <- trans_df %>%
  # per-user timing of every transaction; the first transaction of a user
  # has no predecessor, so its lagged differences are replaced by 0
  group_by(msno) %>%
  mutate(time_from_first_trans = as.numeric(transaction_date - transaction_date[1], 'days'),
         time_from_prev_trans = replace_na(as.numeric(transaction_date - lag(transaction_date), 'days'), 0),
         time_from_prev_exp = replace_na(as.numeric(transaction_date - lag(membership_expire_date_cor), 'days'), 0)) %>%
  ungroup() %>%
  # one output row per subscription period
  group_by(subscr_id) %>%
  summarise(msno = unique(msno),
            msno_subscr_id = unique(msno_subscr_id),
            # 1 if any transaction of the period is flagged as churn
            churn = max(churn),
            # period runs from its first transaction date to the corrected
            # expiration date of its last transaction, capped at 2017-03-31
            start_date = transaction_date[1],
            end_date = pmin(as.Date("2017-03-31"), membership_expire_date_cor[n()]),
            n_trans = n(),
            n_cancel = sum(is_cancel),
            num_auto_renew = sum(is_auto_renew),
            # values taken from the first transaction of the period
            first_payment_method_id = payment_method_id[1],
            first_actual_amount_paid = actual_amount_paid[1],
            first_plan_list_price = plan_list_price[1],
            # most frequent payment method code within the period
            mode_payment_method_id = mode(as.numeric(as.character(payment_method_id))),
            sum_actual_amount_paid = sum(actual_amount_paid),
            # msno_subscr_id below is the per-period value summarised above,
            # so these must stay after it in the argument list
            churned_before = as.numeric(msno_subscr_id > 1),
            churned_before_num = msno_subscr_id - 1,
            # timing values of the period's first transaction
            time_from_first_trans = time_from_first_trans[1],
            time_from_prev_trans = time_from_prev_trans[1],
            time_from_prev_exp = time_from_prev_exp[1]) %>%
  ungroup() %>%
  # period length in days, counting both the start and the end date
  mutate(time = as.numeric(end_date - start_date, 'days') + 1)

# Turn both payment-method columns into factors over the full set of payment
# method codes and pool levels seen fewer than 30 times into level "999".
trans_periods <- local({
  method_codes <- as.numeric(levels(trans_df$payment_method_id))
  lump_rare <- function(codes) {
    fct_lump_min(factor(codes, method_codes), min = 30, other_level = "999")
  }
  mutate(
    trans_periods,
    mode_payment_method_id = lump_rare(mode_payment_method_id),
    first_payment_method_id = lump_rare(first_payment_method_id)
  )
})

# Enrich the members table with per-user transaction and churn summaries,
# the number of subscription periods, and the total time from registration
# to the last (capped) expiration date.
members <- members %>%
  left_join(
    summarise(
      group_by(trans_df, msno),
      transaction_date_first = min(transaction_date),
      transaction_date_last = max(transaction_date),
      # expiration date of the last transaction, capped at 2017-03-31
      membership_expire_date_last = min(last(membership_expire_date_cor),
                                        as.Date("2017-03-31")),
      churned_num = sum(churn == 1),
      churned_last = as.numeric(last(churn) == 1),
      churned_ever = as.numeric(churned_num > 0),
      time_obs = as.numeric(
        membership_expire_date_last - transaction_date_first,
        units = "days"
      ),
      n_trans_tot = n()
    ),
    by = "msno"
  ) %>%
  left_join(count(trans_periods, msno, name = "n_subscr"), by = "msno") %>%
  mutate(
    time_tot = as.numeric(
      membership_expire_date_last - registration_init_time,
      units = "days"
    ) + 1
  )

Out of all transactions 7.4% were indicated as ended with a user’s churn.

Out of all unique users in our data set 60.2% had at least one churn. 44% of all unique users had churn after their last transaction, 16.2% had at least one churn but then they resubscribed to the service.


# Distribution of users by number of churns, with percentage labels
# ("<1%" when the share rounds to zero).
plt_df <- members %>%
  count(churned_num) %>%
  mutate(
    pct = n / sum(n) * 100,
    lbl = sprintf("%.0f%%", pct),
    lbl = ifelse(lbl == "0%", "<1%", lbl)
  )

# Bar chart of the number of churns per user. Percentage labels go inside
# the bar for shares above 2.5% and above the bar otherwise.
p1 <- members %>%
 ggplot() +
  aes(x = churned_num) +
  geom_bar(fill = "#2D559E") +
  geom_text(aes(x = churned_num, y = n - 500, label = lbl), 
            plt_df %>% filter(pct > 2.5), 
            color = "white", fontface = "bold") +
  # `<=` so that a share of exactly 2.5% is still labelled (the complementary
  # filter above is strict)
  geom_text(aes(x = churned_num, y = n + 500, label = lbl), 
            plt_df %>% filter(pct <= 2.5), 
            color = "grey20", fontface = "bold") +
  labs(title = "Number of churns by one user", y = "", x = "") +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0,10000,2000),
                     labels = c(0, paste0(seq(2,10,2), "K")),
                     limits = c(0, 10000)) +
  scale_x_continuous(expand = c(0,0), breaks = plt_df$churned_num) +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 12),
        axis.ticks.x = element_blank())

# Distribution of users by number of subscription periods, with percentage
# labels ("<1%" when the share rounds to zero).
plt_df <- members %>%
  count(n_subscr) %>%
  mutate(
    pct = n / sum(n) * 100,
    lbl = sprintf("%.0f%%", pct),
    lbl = ifelse(lbl == "0%", "<1%", lbl)
  )

# Bar chart of the number of subscription periods per user. Percentage
# labels go inside the bar for shares above 3.5% and above the bar otherwise.
p2 <- ggplot(members, aes(x = n_subscr)) +
  geom_bar(fill = "#2D559E") +
  geom_text(
    mapping = aes(x = n_subscr, y = n - 750, label = lbl),
    data = filter(plt_df, pct > 3.5),
    color = "white", fontface = "bold"
  ) +
  geom_text(
    mapping = aes(x = n_subscr, y = n + 750, label = lbl),
    data = filter(plt_df, pct <= 3.5),
    color = "grey20", fontface = "bold"
  ) +
  scale_x_continuous(expand = c(0, 0), breaks = plt_df$n_subscr) +
  scale_y_continuous(
    expand = c(0, 0),
    limits = c(0, 15000),
    breaks = seq(0, 15000, 2500),
    labels = c(0, paste0(seq(2.5, 15, 2.5), "K"))
  ) +
  labs(title = "Number of subscription periods by user", y = "", x = "") +
  ggthemes::theme_tufte() +
  theme(
    plot.title = element_text(face = "bold"),
    axis.text.x = element_text(size = 12),
    axis.ticks.x = element_blank()
  )

# Subscription-period chart on the left, churn-count chart on the right.
ggarrange(p2, p1, nrow = 1, ncol = 2)


Therefore, for all transactions of each user we obtained the sequence of churn indicators (0 or 1). Then we will combine them into periods of uninterrupted membership, i.e. periods inside of which there were no churns (see the Scheme above):

  • If churn is zero for several consecutive transactions of one user, then we can combine the corresponding membership intervals into one, including the last period in this sequence if it has a churn status of 1. If the sequence ends with zero churn (it is possible only when the last user’s transaction has zero churn status) then overall churn for combined intervals will be 0, and time will be a period from the date of the first transaction in this sequence till the minimum of 2017-03-31 and membership expiration date in the last transaction.

  • If the sequence of zeroes ends with 1 or sequence consists of only one transaction with positive churn status (which is possible if this is the first user’s transaction or a transaction preceded by another one with churn = 1), then overall churn for the sequence will be 1, and time to churn will be a period from the date of the first transaction in the sequence till the membership expiration date for the last transaction in the sequence.

In further analysis we will consider this periods of membership by one user as independent from each other, so a user who returned to the service after the churn will be considered as a new client, except the fact that we will know all information about his previous churns, transactions and other behaviour. Such approach is justified by the definition of churn used in KKBox.

Thus, instead of 19200 unique users we will use data on 25353 periods of subscription. Out of them 14603 (57.6%) ended with a churn, the rest are censored.


3.1.2 Descriptive survival analysis


Kaplan-Meier survival curve and spline-estimated hazard curve for KKBox’s users are presented below:

# Overall (no covariates) Kaplan-Meier fit, its median survival time, and a
# smoothed hazard estimate for the subscription periods.
srvfit <- survfit(Surv(time, churn) ~ 1, data = trans_periods)
srvhaz <- bshazard(Surv(time, churn) ~ 1, data = trans_periods,
                   verbose = FALSE)
med_surv <- surv_median(srvfit)

# Kaplan-Meier curve of active membership for all subscription periods.
p1 <- ggsurvplot(srvfit, size = 1.2, palette = "#006666", conf.int = FALSE,
           ggtheme = theme_minimal() + theme(plot.title = element_text(face = "bold")),
           title = "Probability of active membership (= survival probability)",
           xlab = "Time from the first transaction, months",
           ylab = "Probability of active membership",
           legend = "none", censor = FALSE)

# Re-scale the x axis from days to months (a tick every two months), show
# the y axis in percent, and print the median membership time in the panel.
p1 <- p1$plot +
  scale_x_continuous(expand = c(0, 0), breaks = seq(0, 820, 30.4375*2),
                     labels = seq(0, 820/30.4375, 2),
                     limits = c(0, 820)) +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0,1,0.2),
                     labels = scales::percent_format(accuracy = 1)) +
  # annotate() draws the label once; geom_text() would inherit the survival
  # table and redraw it for every row. "%.0f" formats the (double) median
  # without relying on it being a whole number.
  annotate("text", x = 800, y = 0.9,
           label = sprintf("Median membership time\n(= median survival time):\n%.0f days (%.0f months)", med_surv$median, med_surv$median/30.4375),
           hjust = 1, vjust = 1) +
  theme_classic() +
  theme(legend.position = "none",
        plot.title = element_text(face = "bold"),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

p2 <- ggplot() +
  geom_line(aes(x = srvhaz$time, y = srvhaz$hazard), color = "#D82632", size = 1.2) +
  scale_x_continuous(expand = c(0, 0), breaks = seq(0, 820, 30.4375*2),
                     labels = seq(0, 820/30.4375, 2), limits = c(0,820)) +
  scale_y_continuous(expand = c(0,0), limits = c(0, 0.0125)) +
  labs(x = "Time from the first transaction, months", y = "Hazard rate", 
       title = "Churn hazard rate") +
  theme_classic() +
  theme(plot.title = element_text(face = "bold"),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

ggarrange(p1, p2, nrow = 2, ncol = 1)

– As we can see, the hazard behaviour is closer to the early-churn type: the churn rate is high and falls sharply during the first ~2-3 months after subscription initiation, then it stabilizes for some time, has a small peak between 12 and 14 months, and then gradually decreases.

– There is a noticeable drop in active membership at 1 month after subscription - it could be because of users who subscribe just to try the service, or maybe there are some trial subscription plans at KKBox.


# Group the number of previous churns: 0, 1, 2, 3 keep their own level and
# 4, 5, 6 are merged into "4-6" (values outside 0:6 become NA).
trans_periods <- trans_periods %>%
  mutate(churned_before_numG = factor(churned_before_num, 0:6,
                                      labels = c(0:3, "4-6", "4-6", "4-6")))

# Kaplan-Meier curves stratified by the grouped number of previous churns.
srvfit <- survfit(Surv(time, churn) ~ churned_before_numG, trans_periods)

p1 <- ggsurvplot(srvfit, 
                 palette = paletteer_d("ggsci::light_blue_material")[seq(2,10,2)], 
                 size = 1.2, conf.int = FALSE, 
                 legend.labs = levels(trans_periods$churned_before_numG),
                 legend.title = "",
                 ggtheme = theme_minimal() + theme(plot.title = element_text(face = "bold")),
                 title = "Probability of active membership by\nnumber of previous churns",
                 xlab = "Time from the first transaction, months",
                 ylab = "Probability of active membership",
                 legend = "bottom", censor = FALSE)

# The time axis is in days; ticks every 2 * 30.4375 days, labelled in months.
p1 <- p1$plot +
  scale_x_continuous(expand = c(0, 0), breaks = seq(0, 820, 30.4375*2),
                     labels = seq(0, 820/30.4375, 2),
                     limits = c(0, 820)) +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0,1,0.2),
                     labels = scales::percent_format(accuracy = 1)) +
  theme_classic() +
  theme(legend.position = c(0.9,0.8),
        legend.background = element_blank(),
        plot.title = element_text(face = "bold"),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

# One smoothed hazard curve per level. The per-level pieces are collected
# in a list and bound once at the end instead of growing a data frame with
# rbind() on every iteration.
plt_df <- list()

for (lvl in levels(trans_periods$churned_before_numG)) {
  srvhaz <- bshazard(Surv(time, churn) ~ 1, 
                     trans_periods %>% filter(churned_before_numG == lvl), 
                     verbose = FALSE)
  plt_df[[lvl]] <- tibble(x = srvhaz$time,
                          y = srvhaz$hazard,
                          lvl = lvl)
}

plt_df <- bind_rows(plt_df)

p2 <- ggplot() +
  geom_line(aes(x = x, y = y, color = lvl), plt_df, size = 1.2) +
  scale_x_continuous(expand = c(0, 0), breaks = seq(0, 820, 30.4375*2),
                     labels = seq(0, 820/30.4375, 2), limits = c(0,820)) +
  scale_y_continuous(expand = c(0,0), limits = c(0, 0.0125)) +
  scale_color_manual(values = paletteer_d("ggsci::light_blue_material")[seq(2,10,2)]) +
  labs(x = "Time from the first transaction, months", y = "Hazard rate", 
       title = "Churn hazard rate by\nnumber of previous churns", color = "") +
  theme_classic() +
  theme(legend.position = c(0.9,0.8),
        legend.background = element_blank(),
        plot.title = element_text(face = "bold"),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

ggarrange(p1, p2, nrow = 2, ncol = 1)

– The odd increase in hazard rate between 12 and 14 months is mainly peculiar for users without previous churns. Maybe, it is somehow explained by the characteristics of their subscription plans.

– Hazard rate for users with a higher number of churns falls more gradually at the beginning of the subscription period, but for all groups we have a huge slump in active membership at 1 month after subscription, and this slump is more considerable for users with more previous churns.


# Calendar year and month in which each subscription period started.
trans_periods <- mutate(
  trans_periods,
  year_period = factor(year(start_date)),
  month_period = factor(month(start_date), labels = month.abb)
)

# Number and percentage of subscription periods per start year.
plt_df <- trans_periods %>%
  group_by(year_period) %>%
  tally() %>%
  ungroup() %>%
  mutate(pct = n / sum(n) * 100,
         lbl = sprintf("%.1f%%", pct))

# Bar chart of periods per start year; the percentage label for 2015 is
# drawn in dark grey, the labels for the other years in white.
ggplot(trans_periods, aes(x = year_period)) +
  geom_bar(aes(fill = year_period)) +
  geom_text(aes(x = year_period, y = n - 500, label = lbl),
            data = filter(plt_df, year_period != "2015"),
            fontface = "bold", color = "white") +
  geom_text(aes(x = year_period, y = n - 500, label = lbl),
            data = filter(plt_df, year_period == "2015"),
            fontface = "bold", color = "grey20") +
  labs(title = "Year of subscription", x = "", y = "") +
  scale_y_continuous(limits = c(0, 14000), expand = c(0, 0),
                     breaks = seq(0, 14000, 2000),
                     labels = c(0, paste0(seq(2, 14, 2), "K"))) +
  scale_fill_manual(values = paletteer_c("ggthemes::Blue", 3)) +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 12),
        axis.ticks.x = element_blank(),
        legend.position = "none")

# Kaplan-Meier curves stratified by start year.
srvfit <- survfit(Surv(time, churn) ~ year_period, data = trans_periods)

p1 <- ggsurvplot(
  srvfit,
  title = "Probability of active membership by\nyear of subscription",
  xlab = "Time from the first transaction, months",
  ylab = "Probability of active membership",
  palette = paletteer_c("ggthemes::Blue", 3),
  legend = "bottom",
  legend.title = "",
  legend.labs = levels(trans_periods$year_period),
  size = 1.2, conf.int = FALSE, censor = FALSE,
  ggtheme = theme_minimal() + theme(plot.title = element_text(face = "bold"))
)

# The time axis is in days; ticks every 2 * 30.4375 days, labelled in months.
p1 <- p1$plot +
  scale_x_continuous(limits = c(0, 820), expand = c(0, 0),
                     breaks = seq(0, 820, 30.4375 * 2),
                     labels = seq(0, 820 / 30.4375, 2)) +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0, 1, 0.2),
                     labels = scales::percent_format(accuracy = 1)) +
  theme_classic() +
  theme(plot.title = element_text(face = "bold"),
        legend.position = c(0.9, 0.8),
        legend.title = element_blank(),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

p1

– The most considerable attrition during the first month of subscription is detected for subscriptions started in 2016, the least - for 2017 (although in the latter case there were censored observations at 30 and 31 days after the start of subscription - if we knew their real status (had observed them for a longer period), we could obtain a different form of survival curve for this group of membership periods).


# Number and percentage of subscription periods per start month.
plt_df <- trans_periods %>%
  group_by(month_period) %>%
  tally() %>%
  ungroup() %>%
  mutate(pct = n / sum(n) * 100,
         lbl = sprintf("%.1f%%", pct))

# Bar chart of periods per start month with the percentage printed inside
# the top of each bar.
ggplot(trans_periods, aes(x = month_period)) +
  geom_bar(aes(fill = month_period)) +
  geom_text(aes(x = month_period, y = n - 250, label = lbl),
            data = plt_df,
            fontface = "bold", color = "white") +
  labs(title = "Month of membership start", x = "", y = "") +
  scale_y_continuous(limits = c(0, 6350), expand = c(0, 0),
                     breaks = seq(0, 6000, 2000),
                     labels = c(0, paste0(seq(2, 6, 2), "K"))) +
  scale_fill_manual(values = paletteer_c("grDevices::Set 2", 12)) +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 12),
        axis.ticks.x = element_blank(),
        legend.position = "none")

– High prevalence of subscription started in January can be partially a result of the fact that transaction dataset starts from January, 2015, so for users who registered at the service before that and have regular transactions, their first one in the data will be in this month.

# Kaplan-Meier curves stratified by the calendar month in which the
# subscription period started.
srvfit <- survfit(Surv(time, churn) ~ month_period, data = trans_periods)

p1 <- ggsurvplot(
  srvfit,
  title = "Probability of active membership by\nmonth of subscription",
  xlab = "Time from the first transaction, months",
  ylab = "Probability of active membership",
  palette = paletteer_c("grDevices::Set 2", 12),
  legend = "bottom",
  legend.title = "",
  legend.labs = levels(trans_periods$month_period),
  size = 1.2, conf.int = FALSE, censor = FALSE,
  ggtheme = theme_minimal() + theme(plot.title = element_text(face = "bold"))
)

# The time axis is in days; ticks every 2 * 30.4375 days, labelled in months.
# The legend is laid out in 6 rows.
p1 <- p1$plot +
  scale_x_continuous(limits = c(0, 820), expand = c(0, 0),
                     breaks = seq(0, 820, 30.4375 * 2),
                     labels = seq(0, 820 / 30.4375, 2)) +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0, 1, 0.2),
                     labels = scales::percent_format(accuracy = 1)) +
  guides(color = guide_legend(nrow = 6)) +
  theme_classic() +
  theme(plot.title = element_text(face = "bold"),
        legend.position = c(0.9, 0.8),
        legend.title = element_blank(),
        legend.background = element_blank(),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

p1

– Subscriptions started in November and June (one of the most popular months for starting a subscription, after January) have a higher risk of being interrupted during the first month, those started in January and July - the lowest one.


3.2 Members’ characteristics


Members file contains the following information:


  • city - there are 21 unique numeric codes for cities.
# Number and percentage of unique users per city code; shares that round
# to "0%" are shown as "<1%".
plt_df <- members %>%
  count(city) %>%
  mutate(pct = n / sum(n) * 100,
         lbl = sprintf("%.0f%%", pct),
         lbl = ifelse(lbl == "0%", "<1%", lbl))

# Bar chart of users per city. Labels for shares above 2% go inside the
# bar (white), labels for shares below 2% go above it (dark grey). The
# first 8 factor levels get distinct colours, the rest are grey.
ggplot(members, aes(x = city)) +
  geom_bar(aes(fill = city)) +
  geom_text(aes(x = city, y = n - 250, label = lbl),
            data = filter(plt_df, pct > 2),
            fontface = "bold", color = "white") +
  geom_text(aes(x = city, y = n + 300, label = lbl),
            data = filter(plt_df, pct < 2),
            fontface = "bold", color = "grey20") +
  labs(title = "City", x = "", y = "Number of unique users") +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0, 10000, 2000),
                     labels = c(0, paste0(seq(2, 10, 2), "K"))) +
  scale_fill_manual(values = c(paletteer_d("ggthemes::Tableau_10")[1:8],
                               rep("grey50", nlevels(members$city) - 8))) +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 12),
        axis.ticks.x = element_blank(),
        legend.position = "none")

# Attach member attributes to every subscription period, then derive the
# number of days from registration to period start and a city factor that
# keeps the 8 most frequent cities and lumps the rest into "Other".
trans_periods <- trans_periods %>%
  left_join(select(members, msno, city, bd, gender, registered_via,
                   registration_init_time, year_reg, month_reg),
            by = "msno") %>%
  mutate(time_from_reg = as.numeric(start_date - registration_init_time, 'days'),
         cityG = fct_lump_n(city, 8, other_level = "Other"))

# Kaplan-Meier curves stratified by the lumped city factor.
srvfit <- survfit(Surv(time, churn) ~ cityG, data = trans_periods)

p1 <- ggsurvplot(
  srvfit,
  title = "Probability of active membership by user's city",
  xlab = "Time from the first transaction, months",
  ylab = "Probability of active membership",
  palette = paletteer_d("ggthemes::Tableau_10")[1:nlevels(trans_periods$cityG)],
  legend = "bottom",
  legend.title = "",
  legend.labs = levels(trans_periods$cityG),
  size = 1.2, conf.int = FALSE, censor = FALSE,
  ggtheme = theme_minimal() + theme(plot.title = element_text(face = "bold"))
)

# The time axis is in days; ticks every 2 * 30.4375 days, labelled in months.
p1 <- p1$plot +
  scale_x_continuous(limits = c(0, 820), expand = c(0, 0),
                     breaks = seq(0, 820, 30.4375 * 2),
                     labels = seq(0, 820 / 30.4375, 2)) +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0, 1, 0.2),
                     labels = scales::percent_format(accuracy = 1)) +
  theme_classic() +
  theme(plot.title = element_text(face = "bold"),
        legend.position = "right",
        legend.background = element_blank(),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

p1

– There is no visible difference in patterns of survival probability between users from different cities, except a slightly more pronounced drop in membership during the first month of subscription for users from city 1 (the most populous in the dataset).


  • bd - age. It is an integer variable with a somewhat strange distribution. I suppose users fill in the age or birth date field at registration and may give an untruthful or erroneous value here. Also, it was not defined for which particular moment the age was calculated. Overall, I doubt it is a useful variable by itself.
# Bucket the reported age into three groups. cut() is called with
# right = FALSE, so the intervals are left-closed:
# [min - 1, 10), [10, 61) and [61, max + 1).
# This chunk originally appeared twice in a row, verbatim; one copy is kept.
members <- members %>%
  mutate(bd_group = cut(bd, c(min(members$bd)-1, 10, 61, max(members$bd)+1), 
                        c("< 10", "10-60", "> 60"), right = FALSE))

# Human-readable labels used by tbl_summary().
var_label(members) <- list(bd = "Age", bd_group = "Age group")

# Summary table: mean (SD), median (IQR) and range for age, plus counts
# per age group. `digits` gives 1 decimal for mean and SD, 0 for the rest.
tbl_summary(
  as.data.frame(select(members, bd, bd_group)),
  type = all_continuous() ~ "continuous2",
  statistic = list(all_continuous() ~ c("{mean} ({sd})", "{median} ({p25}-{p75})",  
                                        "{min}-{max}")),
  digits = list(all_continuous() ~ c(rep(1, 2), rep(0, 5)))) %>%
  bold_labels()
# Histogram of age (2-year bins) restricted to users in the "10-60" group.
ggplot(filter(members, bd_group == "10-60"), aes(x = bd)) +
  geom_histogram(fill = "#4682B4", binwidth = 2) +
  labs(title = "Age (for age in [10, 60])", x = "Age",
       y = "Number of unique users") +
  scale_x_continuous(breaks = seq(10, 70, 10)) +
  scale_y_continuous(limits = c(0, 1000), expand = c(0, 0),
                     breaks = seq(0, 1000, 250)) +
  ggthemes::theme_tufte() +
  theme(axis.text.x = element_text(size = 12),
        plot.title = element_text(face = "bold"))
Characteristic N = 19200
Age
Mean (SD) 14.8 (17.7)
Median (25%-75%) 0 (0-28)
Minimum-Maximum 0-940
Age group
< 10 9691 (50%)
10-60 9437 (49%)
> 60 72 (0.4%)
# Two-level age factor for subscription periods: [min - 1, 10) -> "< 10"
# and [10, max + 1) -> ">= 10" (left-closed because right = FALSE). The
# range is taken from `members` so the breaks cover every reported age.
trans_periods <- trans_periods %>%
  mutate(bdG = cut(bd, c(min(members$bd)-1, 10, max(members$bd)+1), 
                   c("< 10", ">= 10"), right = FALSE))

# Kaplan-Meier curves stratified by the age factor.
srvfit <- survfit(Surv(time, churn) ~ bdG, trans_periods)

# The palette is sized by the levels of bdG, the variable being plotted.
# It was previously sized by the levels of cityG, a leftover from the
# city plot.
p1 <- ggsurvplot(srvfit, 
                 palette = paletteer_d("ggthemes::Jewel_Bright")[seq_len(nlevels(trans_periods$bdG))], 
                 size = 1.2, conf.int = FALSE, 
                 legend.labs = levels(trans_periods$bdG),
                 legend.title = "",
                 ggtheme = theme_minimal() + theme(plot.title = element_text(face = "bold")),
                 title = "Probability of active membership by user's age",
                 xlab = "Time from the first transaction, months",
                 ylab = "Probability of active membership",
                 legend = "bottom", censor = FALSE)

# The time axis is in days; ticks every 2 * 30.4375 days, labelled in months.
p1 <- p1$plot +
  scale_x_continuous(expand = c(0, 0), breaks = seq(0, 820, 30.4375*2),
                     labels = seq(0, 820/30.4375, 2),
                     limits = c(0, 820)) +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0,1,0.2),
                     labels = scales::percent_format(accuracy = 1)) +
  theme_classic() +
  theme(legend.position = c(0.9, 0.8),
        legend.title = element_blank(),
        legend.background = element_blank(),
        plot.title = element_text(face = "bold"),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

p1


  • gender
# Number and percentage of unique users per gender level.
plt_df <- members %>%
  count(gender) %>%
  mutate(pct = n / sum(n) * 100,
         lbl = sprintf("%.1f%%", pct))

# Bar chart of users per gender with the percentage printed inside each bar.
ggplot(members, aes(x = gender, fill = gender)) +
  geom_bar(width = 0.6) +
  geom_text(aes(x = gender, y = n - 500, label = lbl),
            data = plt_df,
            fontface = "bold", color = "white") +
  scale_fill_manual(values = paletteer_d("ggthemes::Superfishel_Stone")[c(3, 1, 5)]) +
  scale_y_continuous(limits = c(0, 10000), expand = c(0, 0),
                     breaks = seq(0, 10000, 2000),
                     labels = c(0, paste0(seq(2, 10, 2), "K"))) +
  ggthemes::theme_tufte() +
  labs(title = "Gender", x = "", y = "Number of unique users") +
  theme(legend.position = "none",
        plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 12),
        axis.ticks.x = element_blank())

# Kaplan-Meier curves stratified by gender.
srvfit <- survfit(Surv(time, churn) ~ gender, data = trans_periods)

p1 <- ggsurvplot(
  srvfit,
  title = "Probability of active membership by user's gender",
  xlab = "Time from the first transaction, months",
  ylab = "Probability of active membership",
  palette = paletteer_d("ggthemes::Superfishel_Stone")[c(3, 1, 5)],
  legend = "bottom",
  legend.title = "",
  legend.labs = levels(trans_periods$gender),
  size = 1.2, conf.int = FALSE, censor = FALSE,
  ggtheme = theme_minimal() + theme(plot.title = element_text(face = "bold"))
)

# The time axis is in days; ticks every 2 * 30.4375 days, labelled in months.
p1 <- p1$plot +
  scale_x_continuous(limits = c(0, 820), expand = c(0, 0),
                     breaks = seq(0, 820, 30.4375 * 2),
                     labels = seq(0, 820 / 30.4375, 2)) +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0, 1, 0.2),
                     labels = scales::percent_format(accuracy = 1)) +
  theme_classic() +
  theme(plot.title = element_text(face = "bold"),
        legend.position = c(0.9, 0.8),
        legend.title = element_blank(),
        legend.background = element_blank(),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

p1

– There are some strange differences between users with known (males and females) and unknown gender.


  • registered_via - registration method. There are 5 unique numeric codes for this variable without any explanations from organizers.
# Number and percentage of unique users per registration method code.
plt_df <- members %>%
  count(registered_via) %>%
  mutate(pct = n / sum(n) * 100,
         lbl = sprintf("%.1f%%", pct))

# Bar chart of users per registration method. Labels for shares above 1%
# go inside the bar (white), labels for shares below 1% go above it.
ggplot(members, aes(x = registered_via)) +
  geom_bar(aes(fill = registered_via)) +
  geom_text(aes(x = registered_via, y = n - 200, label = lbl),
            data = filter(plt_df, pct > 1),
            fontface = "bold", color = "white") +
  geom_text(aes(x = registered_via, y = n + 200, label = lbl),
            data = filter(plt_df, pct < 1),
            fontface = "bold", color = "grey20") +
  labs(title = "Registration method", x = "", y = "Number of unique users") +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0, 6000, 1000),
                     labels = c(0, paste0(seq(1, 6, 1), "K"))) +
  scale_fill_manual(values = paletteer_d("ggthemes::Tableau_10")[1:5]) +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 12),
        axis.ticks.x = element_blank(),
        legend.position = "none")

# Kaplan-Meier curves stratified by registration method.
srvfit <- survfit(Surv(time, churn) ~ registered_via, data = trans_periods)

p1 <- ggsurvplot(
  srvfit,
  title = "Probability of active membership by\nregistration method",
  xlab = "Time from the first transaction, months",
  ylab = "Probability of active membership",
  palette = paletteer_d("ggthemes::Tableau_10")[1:5],
  legend = "bottom",
  legend.title = "",
  legend.labs = levels(trans_periods$registered_via),
  size = 1.2, conf.int = FALSE, censor = FALSE,
  ggtheme = theme_minimal() + theme(plot.title = element_text(face = "bold"))
)

# The time axis is in days; ticks every 2 * 30.4375 days, labelled in months.
p1 <- p1$plot +
  scale_x_continuous(limits = c(0, 820), expand = c(0, 0),
                     breaks = seq(0, 820, 30.4375 * 2),
                     labels = seq(0, 820 / 30.4375, 2)) +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0, 1, 0.2),
                     labels = scales::percent_format(accuracy = 1)) +
  theme_classic() +
  theme(plot.title = element_text(face = "bold"),
        legend.position = c(0.9, 0.8),
        legend.title = element_blank(),
        legend.background = element_blank(),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

p1

– There are differences in survival between users with different registration method, although I would say that patterns are similar: a significant drop in membership during and at the end of the first month of subscription followed by gradual decrease. For users registered via 3rd, 4th and 9th methods there is also a drop in subscriptions between 13 and 14 months.


  • registration_init_time - date of initial registration for each client, ranges from 2004-03-26 to 2017-02-28.
# Number of registrations per calendar month: each registration date is
# floored to the first day of its month before counting.
reg_days <- count(
  members,
  registration_init_time = floor_date(registration_init_time, "month")
)

# Monthly registrations as a line chart with one x-axis tick per year.
ggplot(reg_days, aes(x = registration_init_time, y = n)) +
  geom_line(colour = "#112446", size = 0.5) +
  scale_x_date(date_labels = "%Y", date_breaks = "1 year") +
  # scale_y_continuous(limits = c(3000, 6000),
  #                    labels = c(paste0(seq(3,6,1), "K"))) +
  labs(title = "Number of registrations by month", x = "", y = "") +
  ggthemes::theme_tufte() +
  theme(axis.text.x = element_text(size = 12, hjust = 0.5, vjust = 0.5),
        plot.title = element_text(face = "bold"))

– There is a noticeable decreasing trend in the number of monthly registrations at the KKBox service since 2016. Of course, we do not know for how many users information was deleted from the data by KKBox, but if this number did not depend on the year of registration, then we can suppose that we have here a representation of the monthly dynamics of KKBox popularity.

– Since 2011 there are peaks in registrations at the beginning of each year.


# Number and percentage of unique users per registration year.
plt_df <- members %>%
  count(year_reg) %>%
  mutate(pct = n / sum(n) * 100,
         lbl = sprintf("%.1f%%", pct))

# Bar chart of users per registration year. Labels for shares above 1.9%
# go inside the bar (white), labels for shares below 1.9% go above it.
ggplot(members, aes(x = year_reg)) +
  geom_bar(aes(fill = year_reg)) +
  geom_text(aes(x = year_reg, y = n - 200, label = lbl),
            data = filter(plt_df, pct > 1.9),
            fontface = "bold", color = "white") +
  geom_text(aes(x = year_reg, y = n + 200, label = lbl),
            data = filter(plt_df, pct < 1.9),
            fontface = "bold", color = "grey20") +
  labs(title = "Year of registration at the service", x = "",
       y = "Number of unique users") +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0, 6000, 1000),
                     labels = c(0, paste0(seq(1, 6, 1), "K"))) +
  scale_fill_manual(values = paletteer_c("ggthemes::Blue", 14)) +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 12),
        axis.ticks.x = element_blank(),
        legend.position = "none")

# Collapse early registration years into three multi-year bands; years not
# listed here keep their own level.
trans_periods <- mutate(
  trans_periods,
  year_regG = fct_collapse(year_reg,
                           "2004-2009" = as.character(2004:2009),
                           "2010-2011" = as.character(2010:2011),
                           "2012-2013" = as.character(2012:2013))
)

# Kaplan-Meier curves stratified by the grouped registration year.
srvfit <- survfit(Surv(time, churn) ~ year_regG, data = trans_periods)

p1 <- ggsurvplot(
  srvfit,
  title = "Probability of active membership by\nuser's year of registration",
  xlab = "Time from the first transaction, months",
  ylab = "Probability of active membership",
  palette = paletteer_c("ggthemes::Blue", nlevels(trans_periods$year_regG)),
  legend = "bottom",
  legend.title = "",
  legend.labs = levels(trans_periods$year_regG),
  size = 1.2, conf.int = FALSE, censor = FALSE,
  ggtheme = theme_minimal() + theme(plot.title = element_text(face = "bold"))
)

# The time axis is in days; ticks every 2 * 30.4375 days, labelled in months.
p1 <- p1$plot +
  scale_x_continuous(limits = c(0, 820), expand = c(0, 0),
                     breaks = seq(0, 820, 30.4375 * 2),
                     labels = seq(0, 820 / 30.4375, 2)) +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0, 1, 0.2),
                     labels = scales::percent_format(accuracy = 1)) +
  theme_classic() +
  theme(plot.title = element_text(face = "bold"),
        legend.position = c(0.9, 0.8),
        legend.title = element_blank(),
        legend.background = element_blank(),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

p1

– In comparison with users registered in 2010-2013, new ones have a more considerable drop in active membership during the first month.


# Bucket the number of days between registration and subscription start.
# Intervals are left-closed (right = FALSE): [0, 0.1) holds exact zeros,
# then [0.1, 31), [31, 366), [366, 1098), [1098, 1825), [1825, 5490).
# NOTE(review): the year boundaries mix 366-day (366*3, 366*15) and
# 365-day (365*5) multipliers - confirm whether this is intended.
trans_periods <- mutate(
  trans_periods,
  time_from_regG = cut(
    time_from_reg,
    breaks = c(0, 0.1, 31, 366, 366 * 3, 365 * 5, 366 * 15),
    labels = c("0", "1-30 days", "31 days-1year", "1-3 years", "3-5 years", ">5 years"),
    right = FALSE
  )
)

# Human-readable labels used by tbl_summary().
var_label(trans_periods) <- list(
  time_from_reg = "Time from user's registration at the service to subscription start, days",
  time_from_regG = "Time from user's registration at the service to subscription start"
)

# Summary table: mean (SD), median (IQR) and range for the continuous
# variable, counts per bucket for the grouped one.
trans_periods %>%
  select(time_from_reg, time_from_regG) %>%
  as.data.frame() %>%
  tbl_summary(
    type = all_continuous() ~ "continuous2",
    statistic = list(all_continuous() ~ c("{mean} ({sd})",
                                          "{median} ({p25}-{p75})",
                                          "{min}-{max}")),
    digits = list(all_continuous() ~ c(rep(1, 2), rep(0, 5)))
  ) %>%
  bold_labels()
Characteristic N = 25353
Time from user's registration at the service to subscription start, days
Mean (SD) 878.4 (1024.1)
Median (25%-75%) 502 (37-1326)
Minimum-Maximum 0-4610
Time from user's registration at the service to subscription start
0 4404 (17%)
1-30 days 1794 (7.1%)
31 days-1year 5007 (20%)
1-3 years 6436 (25%)
3-5 years 3685 (15%)
>5 years 4027 (16%)
# Kaplan-Meier curves stratified by the grouped time from registration to
# subscription start.
srvfit <- survfit(Surv(time, churn) ~ time_from_regG, data = trans_periods)

p1 <- ggsurvplot(
  srvfit,
  title = "Probability of active membership by\ntime from registration",
  xlab = "Time from the first transaction, months",
  ylab = "Probability of active membership",
  palette = paletteer_c("ggthemes::Brown", nlevels(trans_periods$time_from_regG)),
  legend = "bottom",
  legend.title = "",
  legend.labs = levels(trans_periods$time_from_regG),
  size = 1.2, conf.int = FALSE, censor = FALSE,
  ggtheme = theme_minimal() + theme(plot.title = element_text(face = "bold"))
)

# The time axis is in days; ticks every 2 * 30.4375 days, labelled in months.
# The legend is laid out in 6 rows.
p1 <- p1$plot +
  scale_x_continuous(limits = c(0, 820), expand = c(0, 0),
                     breaks = seq(0, 820, 30.4375 * 2),
                     labels = seq(0, 820 / 30.4375, 2)) +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0, 1, 0.2),
                     labels = scales::percent_format(accuracy = 1)) +
  guides(color = guide_legend(nrow = 6)) +
  theme_classic() +
  theme(plot.title = element_text(face = "bold"),
        legend.position = c(0.9, 0.85),
        legend.title = element_blank(),
        legend.background = element_blank(),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

p1

– Initial subscriptions (those started on the day of the user’s registration) tend to have the lowest risk of attrition, while those which started between 1 month and 1 year from registration - the highest one.


# Number and percentage of unique users per registration month.
plt_df <- members %>%
  count(month_reg) %>%
  mutate(pct = n / sum(n) * 100,
         lbl = sprintf("%.1f%%", pct))

# Bar chart of users per registration month. Labels for shares above 1.9%
# go inside the bar (white), labels for shares below 1.9% go above it.
ggplot(members, aes(x = month_reg)) +
  geom_bar(aes(fill = month_reg)) +
  geom_text(aes(x = month_reg, y = n - 100, label = lbl),
            data = filter(plt_df, pct > 1.9),
            fontface = "bold", color = "white") +
  geom_text(aes(x = month_reg, y = n + 200, label = lbl),
            data = filter(plt_df, pct < 1.9),
            fontface = "bold", color = "grey20") +
  labs(title = "Month of registration at the service", x = "",
       y = "Number of unique users") +
  scale_y_continuous(limits = c(0, 2000), expand = c(0, 0),
                     breaks = seq(0, 6000, 500)) +
  scale_fill_manual(values = paletteer_c("grDevices::Set 2", 12)) +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 12),
        axis.ticks.x = element_blank(),
        legend.position = "none")

# Kaplan-Meier curves stratified by the user's registration month.
srvfit <- survfit(Surv(time, churn) ~ month_reg, data = trans_periods)

p1 <- ggsurvplot(
  srvfit,
  title = "Probability of active membership by\nuser's month of registration",
  xlab = "Time from the first transaction, months",
  ylab = "Probability of active membership",
  palette = paletteer_c("grDevices::Set 2", 12),
  legend = "bottom",
  legend.title = "",
  legend.labs = levels(trans_periods$month_reg),
  size = 1.2, conf.int = FALSE, censor = FALSE,
  ggtheme = theme_minimal() + theme(plot.title = element_text(face = "bold"))
)

# The time axis is in days; ticks every 2 * 30.4375 days, labelled in months.
# The legend is laid out in 6 rows.
p1 <- p1$plot +
  scale_x_continuous(limits = c(0, 820), expand = c(0, 0),
                     breaks = seq(0, 820, 30.4375 * 2),
                     labels = seq(0, 820 / 30.4375, 2)) +
  scale_y_continuous(expand = c(0, 0), breaks = seq(0, 1, 0.2),
                     labels = scales::percent_format(accuracy = 1)) +
  guides(color = guide_legend(nrow = 6)) +
  theme_classic() +
  theme(plot.title = element_text(face = "bold"),
        legend.position = c(0.9, 0.8),
        legend.title = element_blank(),
        legend.background = element_blank(),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))

p1

– I would not say that there are considerable differences in pattern of survival curve depending on the month of registration.


3.3 User logs


User log files contain 3959349 records on daily listening behaviors of sampled users from 2015-01-01 to 2017-03-31.


# Number of user-log records on each calendar day.
ulogs_days <- count(ulogs, date)

var_label(ulogs_days$n) <- "Number of records per day in user logs"

# Descriptive statistics of the daily counts: mean (SD), median (IQR), range.
ulogs_days %>%
  select(n) %>%
  as.data.frame() %>%
  tbl_summary(
    type = all_continuous() ~ "continuous2",
    statistic = list(all_continuous() ~ c("{mean} ({sd})",
                                          "{median} ({p25}-{p75})",
                                          "{min}-{max}")),
    digits = list(all_continuous() ~ c(1, 1, 0, 0, 0, 0, 0))
  ) %>%
  bold_labels()
Characteristic N = 821
Number of records per day in user logs
Mean (SD) 4822.6 (658.4)
Median (25%-75%) 4811 (4188-5384)
Minimum-Maximum 3413-5988


# Line chart of the number of user-log records per day over the whole
# observation window (one x tick per month).
ggplot(ulogs_days) +
  aes(x = date, y = n) +
  geom_line(size = 0.5, colour = "#112446") +
  scale_x_date(date_breaks = "1 month", date_labels =  "%Y-%m", 
               limits = as.Date(c('2015-01-01','2017-03-31'))) +
  # Breaks are stated explicitly so that they always match the four "K"
  # labels; ggplot2 stops with an error when `breaks` and `labels` differ in
  # length, and the automatically chosen breaks are not guaranteed to be four.
  scale_y_continuous(limits = c(3000, 6000),
                     breaks = seq(3000, 6000, 1000),
                     labels = paste0(seq(3, 6, 1), "K")) +
  labs(title = "Number of records per day in user logs", x = "", y = "") +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 10, angle = 90, hjust = 0.5, vjust = 0.5))


# One row per user: the start/end dates of every subscription period are
# spread into start_date.<k> / end_date.<k> columns, where <k> is the value of
# msno_subscr_id.
trans_periods_wide <- reshape(
  as.data.frame(select(trans_periods, msno, msno_subscr_id, start_date, end_date)),
  direction = "wide",
  idvar = "msno",
  timevar = "msno_subscr_id",
  v.names = c("start_date", "end_date")
)

# Attach the period boundaries to every log record and locate the record
# relative to them:
#   0        - before the start of period 1
#   k        - inside period k (k = 1..7)
#   k + 0.5  - after the end of period k and before the start of period k + 1
#   7.5      - everything not matched above
# A comparison against a missing boundary is NA, which case_when() treats as
# "no match", so such records fall through to the next condition.
# `in_subscr` flags the records that lie inside one of the periods.
ulogs <- left_join(ulogs, trans_periods_wide, by = "msno") %>%
  mutate(
    msno_subscr_period = case_when(
      date <  start_date.1 ~ 0,
      date <= end_date.1   ~ 1,
      date <  start_date.2 ~ 1.5,
      date <= end_date.2   ~ 2,
      date <  start_date.3 ~ 2.5,
      date <= end_date.3   ~ 3,
      date <  start_date.4 ~ 3.5,
      date <= end_date.4   ~ 4,
      date <  start_date.5 ~ 4.5,
      date <= end_date.5   ~ 5,
      date <  start_date.6 ~ 5.5,
      date <= end_date.6   ~ 6,
      date <  start_date.7 ~ 6.5,
      date <= end_date.7   ~ 7,
      TRUE                 ~ 7.5
    ),
    in_subscr = msno_subscr_period %in% 1:7
  )

# Keep only the log records that fall inside a subscription period and attach
# the attributes of that period (and of its user) from trans_periods. The
# period number found for each record becomes the join key msno_subscr_id.
ulogs_periods <- left_join(
  rename(filter(ulogs, in_subscr), msno_subscr_id = msno_subscr_period),
  select(trans_periods,
         msno, msno_subscr_id, subscr_id, start_date, end_date, churn, time,
         city, cityG, bd, bdG, gender, registered_via,
         registration_init_time, year_reg, year_regG, month_reg,
         churned_before, churned_before_num),
  by = c("msno", "msno_subscr_id")
)

# Flag the subscription periods that have at least one log record.
trans_periods <- mutate(trans_periods,
                        subscr_got_logs = subscr_id %in% ulogs_periods$subscr_id)

# Per user: number of periods with and without logs, plus a flag telling
# whether every one of the user's periods has logs.
members <- left_join(
  members,
  trans_periods %>%
    select(msno, subscr_got_logs) %>%
    group_by(msno) %>%
    summarise(n_subscr_logs = sum(subscr_got_logs),
              n_subscr_nologs = sum(!subscr_got_logs)) %>%
    ungroup(),
  by = "msno"
) %>%
  mutate(all_subscr_got_logs = n_subscr_logs == n_subscr)

The statistics above were estimated on all records in the user logs file for the 19200 sampled users. After that, for each log record we identified which of the previously constructed membership periods of the corresponding user it belongs to. 5.6% of records were dated earlier than the start of the user’s first subscription period, 0.4% were dated after the expiration date of the user’s last subscription period, and about 3.3% fell between the user’s subscription periods.

Taking into account the study nature of this project, our assumption that active membership began at the date of the first subscription for each user, as well as all examples that confirm the dirtiness of the data, and relatively small amount of logs in between subscriptions, we have decided to omit all records outside the identified subscription periods.

As a result, 3592849 records remained. They contain data on 18144 unique users (94.5% of those initially sampled) and 23536 unique subscription periods (92.8% of all identified periods). For the remaining users and periods we will assume that the user never logged into the service.


# Per user: number of log records that fall inside a subscription period
# (0 for users whose records are all outside their periods).
ulogs_stat <- summarise(group_by(ulogs, msno), n = sum(in_subscr))

var_label(ulogs_stat$n) <- "Number of records in user logs per ID"

# Descriptive statistics of the per-user counts: mean (SD), median (IQR), range.
ulogs_stat %>%
  select(n) %>%
  as.data.frame() %>%
  tbl_summary(
    type = all_continuous() ~ "continuous2",
    statistic = list(all_continuous() ~ c("{mean} ({sd})",
                                          "{median} ({p25}-{p75})",
                                          "{min}-{max}")),
    digits = list(all_continuous() ~ c(1, 1, 0, 0, 0, 0, 0))
  ) %>%
  bold_labels()
Characteristic N = 19200
Number of records in user logs per ID
Mean (SD) 187.1 (205.0)
Median (25%-75%) 106 (12-315)
Minimum-Maximum 0-807


All statistics on user logs below concern only the records that remained after excluding records outside of users’ subscription periods.


– For 78.3% of all users’ successive records in user logs the break between them was 1 day.

# Per user: mean and median of time_from_prev_date, computed over the records
# where that value is present.
ulogs_breaks <- ulogs_periods %>%
  filter(!is.na(time_from_prev_date)) %>%
  group_by(msno) %>%
  summarise(
    mean_break = mean(time_from_prev_date),
    median_break = median(time_from_prev_date)
  ) %>%
  ungroup()

var_label(ulogs_breaks$mean_break) <- "Mean break between logs for one ID, days"
var_label(ulogs_breaks$median_break) <- "Median break between logs for one ID, days"

# Descriptive statistics of both per-user measures.
ulogs_breaks %>%
  select(mean_break, median_break) %>%
  as.data.frame() %>%
  tbl_summary(
    type = all_continuous() ~ "continuous2",
    statistic = list(all_continuous() ~ c("{mean} ({sd})",
                                          "{median} ({p25}-{p75})",
                                          "{min}-{max}")),
    digits = list(all_continuous() ~ c(1, 1, 0, 0, 0, 0, 0))
  ) %>%
  bold_labels()
Characteristic N = 17022
Mean break between logs for one ID, days
Mean (SD) 4.5 (15.3)
Median (25%-75%) 1.8 (1.3-3.2)
Minimum-Maximum 1.0-522.0
Median break between logs for one ID, days
Mean (SD) 2.5 (13.4)
Median (25%-75%) 1.0 (1.0-1.0)
Minimum-Maximum 1.0-522.0


  • num_25, num_50, num_75, num_985, num_100 - number of songs played less than 25% of the song length, between 25% to 50%, between 50% to 75%, between 75% to 98.5% and over 98.5% of the song length, correspondingly.

    I’ve summed up all these numbers to obtain the total number of songs a user listened to - num_tot. There are no records with a zero total, which confirms that logs were written only for active users.

    In addition, I’ve estimated the share of each song-length group in the total number of songs listened to (sh_25, sh_50, sh_75, sh_985, sh_100) as \(\text{sh_X} = \text{num_X}/\text{num_tot}\) and the average % of one song’s length played (av_pct_played) as \[\text{av_pct_played} = \frac{\text{num_25} \times 0.125 + \text{num_50} \times 0.375 + \text{num_75} \times 0.625 + \text{num_985} \times 0.8675 + \text{num_100} \times 0.9925}{\text{num_tot}} \times 100\%\]


– Here is an example of several users’ activity during their subscription periods by the total number of songs played:

# Example users: the first user found for every observed combination of
# (number of subscription periods, number of periods that have logs).
# The ids are kept in their own vector instead of re-using the name
# `msno_sample_act` for two different objects, and the helper tibble is
# ungrouped before the ids are extracted.
msno_sample_ids <- members %>%
  select(msno, n_subscr, n_subscr_logs) %>%
  group_by(n_subscr, n_subscr_logs) %>%
  filter(row_number() == 1) %>%
  ungroup() %>%
  pull(msno) %>%
  unique()

# Subscription periods of the example users together with their daily logs
# (a period without logs keeps a single row with missing date / num_tot).
# `msno_id` is a short display id: 1, 2, ... in the order of `msno_sample_ids`;
# seq_along() keeps the labels consistent with the levels even when the
# sample is empty, where 1:length() would yield c(1, 0).
msno_sample_act <- trans_periods %>%
  select(msno, subscr_id, start_date, end_date) %>%
  filter(msno %in% msno_sample_ids) %>%
  left_join(ulogs_periods %>% select(subscr_id, date, num_tot, churn), by = "subscr_id") %>%
  mutate(msno_id = factor(msno,
                          levels = msno_sample_ids,
                          labels = seq_along(msno_sample_ids)))

# Activity chart of the example users: one row per user, a one-day-long
# segment for every day with logs (shaded by the total number of songs) and
# point markers at the start and end of each subscription period.
ggplot() +
  geom_segment(data = filter(msno_sample_act, !is.na(num_tot)),
               mapping = aes(x = date, xend = date + 1,
                             y = msno_id, yend = msno_id,
                             color = num_tot),
               size = 8) +
  geom_point(data = msno_sample_act,
             mapping = aes(x = start_date, y = msno_id, shape = "Start of subscr."),
             color = "#35978F", size = 1) +
  geom_point(data = msno_sample_act,
             mapping = aes(x = end_date, y = msno_id, shape = "End of subscr."),
             color = "#FDAE61", size = 1) +
  labs(title = "Sampled users' daily activity",
       x = "", y = "User", color = "Tot.num.of songs", shape = "") +
  scale_x_date(date_breaks = "1 month", date_labels =  "%Y-%m",
               expand = c(0.01,0)) +
  scale_color_gradient(low = brewer.pal(9, "Reds")[1],
                       high = brewer.pal(9, "Reds")[9],
                       breaks = seq(0, 250, 50),
                       labels = seq(0, 250, 50)) +
  scale_shape_manual(breaks = c("Start of subscr.", "End of subscr."),
                     values = c(19, 15)) +
  # The legend keys are drawn larger and in the marker colours used above.
  guides(shape = guide_legend(override.aes = list(size = 3,
                                                  color = c("#35978F", "#FDAE61")))) +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 9, angle = 90, hjust = 0.5, vjust = 0.5))


  • num_unq - number of unique songs played.

    I’ve divided the total number of songs by the number of unique songs played to obtain the average number of times of listening to one song (av_times_played). In addition, I’ve calculated the difference between the total number of songs and the number of unique songs played to obtain the number of repeats of the same songs (num_repeats).


# Histogram of one daily-log variable.
#
# The ten panels below used to be ten copies of the same ggplot recipe that
# differed only in a handful of numbers; the recipe now lives in one place.
#
#   var       column name (string); its variable label becomes the title
#   binwidth  histogram bin width
#   x_breaks  breaks of the x axis
#   y_breaks  breaks of the y axis; the first one is labelled "0", the others
#             are labelled as <break / y_unit><y_suffix>
#   y_unit, y_suffix  see y_breaks (defaults give thousands: "300K")
#   boundary  bin boundary for geom_histogram() (NULL = ggplot2 default)
#   x_limits, y_limits  optional axis limits (NULL = none)
#   trim_p99  if TRUE, only observations strictly below the 99th percentile
#             of `var` are drawn and the title says so
#   data      data frame holding `var`
#
# Returns a ggplot object.
plot_ulogs_hist <- function(var, binwidth, x_breaks, y_breaks,
                            y_unit = 1000, y_suffix = "K",
                            boundary = NULL, x_limits = NULL, y_limits = NULL,
                            trim_p99 = TRUE, data = ulogs_periods) {
  plot_title <- var_label(data)[[var]]
  if (trim_p99) {
    # The cut-off is taken from the full column before any row is dropped.
    cutoff <- quantile(data[[var]], 0.99)
    data <- filter(data, .data[[var]] < cutoff)
    plot_title <- paste0(plot_title,
                         "\n(observations with values less than 99th percentile)")
  }

  ggplot(data, aes(x = .data[[var]])) +
    geom_histogram(binwidth = binwidth, boundary = boundary, fill = "#4682B4") +
    labs(x = "", y = "", title = plot_title) +
    scale_x_continuous(breaks = x_breaks, limits = x_limits) +
    scale_y_continuous(expand = c(0,0), breaks = y_breaks,
                       labels = c(0, paste0(y_breaks[-1] / y_unit, y_suffix)),
                       limits = y_limits) +
    ggthemes::theme_tufte() +
    theme(plot.title = element_text(face = "bold"),
          axis.text.x = element_text(size = 12))
}

p1 <- plot_ulogs_hist("num_25", binwidth = 2,
                      x_breaks = seq(0, 60, 10),
                      y_breaks = seq(0, 1500000, 300000))

p2 <- plot_ulogs_hist("num_50", binwidth = 1,
                      x_breaks = seq(0, 16, 1),
                      y_breaks = seq(0, 1600000, 400000))

p2a <- plot_ulogs_hist("num_75", binwidth = 1,
                       x_breaks = seq(0, 16, 1),
                       y_breaks = seq(0, 2000000, 500000))

p3 <- plot_ulogs_hist("num_985", binwidth = 1,
                      x_breaks = seq(0, 16, 1),
                      y_breaks = seq(0, 2000000, 500000),
                      y_limits = c(0, 2000000))

p4 <- plot_ulogs_hist("num_100", binwidth = 5,
                      x_breaks = seq(0, 180, 20),
                      y_breaks = seq(0, 1400000, 300000))

# The only panel drawn on all observations (no percentile trimming).
p4a <- plot_ulogs_hist("av_pct_played", binwidth = 1,
                       x_breaks = seq(0, 100, 20),
                       x_limits = c(0, 100),
                       y_breaks = seq(0, 600000, 200000),
                       trim_p99 = FALSE)

p5 <- plot_ulogs_hist("num_tot", binwidth = 5, boundary = 0,
                      x_breaks = seq(0, 200, 20),
                      y_breaks = seq(0, 500000, 100000),
                      y_limits = c(0, 500000))

p6 <- plot_ulogs_hist("num_unq", binwidth = 5, boundary = 0,
                      x_breaks = seq(0, 160, 20),
                      y_breaks = seq(0, 1000000, 200000))

p7 <- plot_ulogs_hist("num_repeats", binwidth = 5,
                      x_breaks = seq(0, 120, 20),
                      y_breaks = seq(0, 2000000, 500000),
                      y_limits = c(0, 2000000))

# y axis in millions, labelled "1KK", "2KK", "3KK".
p8 <- plot_ulogs_hist("av_times_played", binwidth = 1, boundary = 0,
                      x_breaks = seq(1, 13, 1),
                      y_breaks = seq(0, 3000000, 1000000),
                      y_unit = 1000000, y_suffix = "KK")

ggarrange(p1, p2, p2a, p3, p4, p4a, p5, p6, p7, p8, ncol = 2, nrow = 5)


# Empirical CDFs of the five song-length shares, one curve per share.
# The primary axes read as the CDF (share <= x); the secondary axes show the
# same curves as 1-CDF (share >= x), hence the mirrored y axis (1 - .) and
# the ">=" labels on the top axis.
ggplot(data = ulogs_periods) +
  stat_ecdf(mapping = aes(x = sh_100, color = "98.5-100%"), size = 1) +
  stat_ecdf(mapping = aes(x = sh_985, color = "75-98.5%"), size = 1) +
  stat_ecdf(mapping = aes(x = sh_75, color = "50-75%"), size = 1) +
  stat_ecdf(mapping = aes(x = sh_50, color = "25-50%"), size = 1) +
  stat_ecdf(mapping = aes(x = sh_25, color = "0-25%"), size = 1) +
  labs(title = "Cumulative distribution function for shares of\nthe number of songs played by length",
       x = "For CDF: share of this length songs out of the total number of songs played",
       y = "CDF",
       color = "% of song length played") +
  scale_y_continuous(
    breaks = seq(0, 1, 0.2),
    labels = scales::percent_format(accuracy = 1),
    expand = c(0.01, 0),
    sec.axis = sec_axis(trans = ~ 1 - .,
                        name = "1-CDF",
                        breaks = seq(0, 1, 0.2),
                        labels = scales::percent_format(accuracy = 1))
  ) +
  scale_x_continuous(
    breaks = seq(0, 1, 0.2),
    labels = c(0, paste0("<= ", seq(0.2,1,0.2))),
    expand = c(0.01,0),
    sec.axis = sec_axis(trans = ~ .,
                        name = "For 1-CDF: share of this length songs out of the total number of songs played",
                        breaks = seq(0, 1, 0.2),
                        labels = c(paste0(">= ", seq(0,0.8,0.2)), 1))
  ) +
  scale_color_viridis_d(direction = -1, end = 0.9) +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        panel.grid.major = element_line(color = "grey90", size = 0.2),
        panel.grid.minor = element_line(color = "grey90", size = 0.2))

– There are almost no records in which a user did not play a single song in full during the day. About 15% of records contain only completely played songs. For about 80% of records at least half of the songs were listened to completely, and for 50% of records at least 8 out of 10.

– Intermediate categories (num_50, num_75 and num_985) are quite rare and for more than 95% of records their share is less than 30% out of the total number of songs played that day.


# Density (half-eye) with a narrow boxplot underneath for the average % of a
# song's length played; the plot is flipped so the variable runs horizontally.
# The density axis is relabelled: positions 0.06, 0.26, ... are shown as
# 0, 0.2, ...
ggplot(data = ulogs_periods, mapping = aes(y = av_pct_played)) +
  ggdist::stat_halfeye(fill = "#4682B4", alpha = 0.8,
                       adjust = 0.4, justification = -0.07,
                       .width = 0, point_colour = NA) +
  geom_boxplot(fill = "#4682B4", color = "#4682B4", alpha = 0.5,
               width = 0.1, outlier.colour = NA, show.legend = FALSE) +
  labs(title = "Average % of one song's length played",
       x = "Density",
       y = var_label(ulogs_periods)$av_pct_played) +
  scale_x_continuous(limits = c(-0.1, 1.1),
                     breaks = seq(0.06, 1.06, 0.2),
                     labels = seq(0,1,0.2),
                     expand = c(0.01, 0)) +
  scale_y_continuous(limits = c(10, 100),
                     breaks = seq(10,100,10),
                     expand = c(0.01,0)) +
  coord_flip() +
  ggthemes::theme_tufte()


# Clean the total listening time and derive the per-record estimate of a
# song's real length that is later summarised by its median for imputation.
ulogs_periods <- ulogs_periods %>%
  mutate(
    # Negative totals and totals above 24 hours (86400 s) are set to missing.
    total_secs_na = case_when(total_secs < 0 | total_secs > 60*60*24 ~ NA_real_,
                              TRUE ~ total_secs),
    total_hrs_na = total_secs_na/3600,
    # Built from the cleaned total (total_secs_na), not the raw one, so the
    # invalid values removed above do not leak into the statistic that is
    # used to impute them.
    # NOTE(review): the denominator divides each count by its played
    # fraction, whereas av_pct_played multiplies by it - confirm that the
    # division is the intended model of a song's real length.
    av_secs_real_na = total_secs_na /
      (num_25/0.125 + num_50/0.375 + num_75/0.625 + num_985/0.8675 + num_100/0.9925)
  )
  • total_secs - total seconds played. This column contains 561 records with negative values - they were replaced with missing values. The same was done for 1254 records with values higher than 86400 seconds (24 hours).

    To impute these missing values I use the following approach. First, I’ve calculated the average real length of one song for each record (av_secs_real) with the following formula: \[\text{Average real length} = \frac{\text{total_secs}}{\frac{\text{num_25}}{0.125}+\frac{\text{num_50}}{0.375}+\frac{\text{num_75}}{0.625}+\frac{\text{num_985}}{0.8675}+\frac{\text{num_100}}{0.9925}}\]

    The median average real length of one song in the whole dataset is about 114 seconds (median_length). Using it, I’ve calculated the imputed value of the total seconds played (total_secs_imp) for observations with missing values in the following way, where sh_tot is the number of songs weighted by the share of their length played: \[\begin{aligned}\text{total_secs_imp} &= \min(24 \times 60 \times 60,\ \text{median_length} \times \text{sh_tot}) \\ &= \min(24 \times 60 \times 60,\ \text{median_length} \times (\text{num_25} \times 0.125 + \text{num_50} \times 0.375 + \text{num_75} \times 0.625 + \text{num_985} \times 0.8675 + \text{num_100} \times 0.9925))\end{aligned}\]


# Impute the missing daily totals and derive per-song durations.
ulogs_periods <- ulogs_periods %>%
  mutate(
    # Imputed total = median real song length (rounded to whole seconds)
    # multiplied by the number of songs weighted by the share of their length
    # that was played, capped at 24 hours.
    # pmin() applies the cap row by row; the scalar min() used before returned
    # one global minimum for the whole column, and the weighted song count was
    # replaced by av_pct_played/100, which yields a per-song value rather than
    # a daily total.
    total_secs_imp = ifelse(
      is.na(total_secs_na),
      pmin(round(median(ulogs_periods$av_secs_real_na, na.rm = TRUE), 0) *
             (num_25*0.125 + num_50*0.375 + num_75*0.625 +
                num_985*0.8675 + num_100*0.9925),
           86400),
      total_secs_na),
    total_hrs_imp = total_secs_imp/3600,
    # Same length formula as av_secs_real_na, now on the imputed totals.
    av_secs_real = total_secs_imp /
      (num_25/0.125 + num_50/0.375 + num_75/0.625 + num_985/0.8675 + num_100/0.9925),
    # Average listening time per song (of any played fraction).
    av_secs_played = total_secs_imp / num_tot,
    av_mins_real = av_secs_real/60,
    av_mins_played = av_secs_played/60
  )

# Human-readable labels for the duration variables (used in tables and as
# plot titles).
var_label(ulogs_periods) <- list(
  # totals per record
  total_secs_na = "Total seconds played",
  total_hrs_na = "Total hours played",
  total_secs_imp = "Total seconds played (imputed)",
  total_hrs_imp = "Total hours played (imputed)",
  # per-song averages
  av_secs_real = "Average real length of a song, seconds",
  av_mins_real = "Average real length of a song, minutes",
  av_secs_played = "Average time spent on listening to one song, seconds",
  av_mins_played = "Average time spent on listening to one song, minutes"
)

# Descriptive statistics, all shown with one decimal place.
ulogs_periods %>%
  select(av_mins_played, total_hrs_na, total_hrs_imp) %>%
  as.data.frame() %>%
  tbl_summary(
    type = all_continuous() ~ "continuous2",
    statistic = list(all_continuous() ~ c("{mean} ({sd})",
                                          "{median} ({p25}-{p75})",
                                          "{min}-{max}")),
    digits = list(all_continuous() ~ rep(1, 7))
  ) %>%
  bold_labels()
Characteristic N = 3592849
Average time spent on listening to one song, minutes
Mean (SD) 3.3 (1.1)
Median (25%-75%) 3.4 (2.7-3.9)
Minimum-Maximum 0.0-575.2
Total hours played
Mean (SD) 2.3 (2.7)
Median (25%-75%) 1.3 (0.6-2.9)
Minimum-Maximum 0.0-24.0
Unknown 1815
Total hours played (imputed)
Mean (SD) 2.3 (2.7)
Median (25%-75%) 1.3 (0.6-2.9)
Minimum-Maximum 0.0-24.0


# Histogram of the average listening time per song (minutes), drawn on the
# observations strictly below the 99th percentile.
p1 <- ggplot(
  data = filter(ulogs_periods,
                av_mins_played < quantile(ulogs_periods$av_mins_played, 0.99, na.rm = TRUE)),
  mapping = aes(x = av_mins_played)
) +
  geom_histogram(fill = "#4682B4", binwidth = 1, boundary = 0) +
  scale_y_continuous(breaks = seq(0,1500000,500000),
                     labels = c(0, paste0(seq(500,1500,500), "K")),
                     expand = c(0,0)) +
  scale_x_continuous(breaks = seq(0, 5, 1)) +
  labs(title = paste0(var_label(ulogs_periods)$av_mins_played, "\n(observations with values less than 99th percentile)"),
       x = "Minutes", y = "") +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 12))

# Same for the imputed total hours played per record.
p2 <- ggplot(
  data = filter(ulogs_periods,
                total_hrs_imp < quantile(ulogs_periods$total_hrs_imp, 0.99, na.rm = TRUE)),
  mapping = aes(x = total_hrs_imp)
) +
  geom_histogram(fill = "#4682B4", binwidth = 1, boundary = 0) +
  scale_y_continuous(breaks = seq(0,1500000,500000),
                     labels = c(0, paste0(seq(500,1500,500), "K")),
                     expand = c(0,0)) +
  scale_x_continuous(breaks = seq(0, 12, 1)) +
  labs(title = paste0(var_label(ulogs_periods)$total_hrs_imp, "\n(observations with values less than 99th percentile)"),
       x = "Hours", y = "") +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 12))

# Both panels side by side.
ggarrange(p1, p2, ncol = 2, nrow = 1)


# Raster (binned 2-D count) of the imputed total hours played against the
# total number of songs played per record. The title typo "numer" is fixed.
p1 <- dbplot_raster(ulogs_periods,
              x = num_tot, y = total_hrs_imp) +
  scale_y_continuous(expand = c(0,0), breaks = seq(0,24,3), limits = c(0,24)) +
  scale_x_continuous(expand = c(0,0), breaks = seq(0,6000,1000)) +
  scale_fill_viridis_c() +
  labs(x = "Total number of songs played", y = "Total hours played",
       title = "Total hours played vs. total number of songs",
       fill = "Number of obs.") +
  ggthemes::theme_tufte() +
  theme(plot.title = element_text(face = "bold"),
        legend.key.width = unit(1,"cm"))

p1


In addition to the above statistics, we have estimated on how many days of the corresponding subscription period a user logged into the service (was active) - days_active, and the share of active days in the full length of the subscription period (in %) - pct_active. We did the same for the last 30, 14 and 7 days of each subscription period.

# Per subscription period: number of active days and the SUMS of the daily
# activity measures, over the whole period and over its last 30 / 14 / 7 days.
# NOTE: the columns are named mean_* but hold sums at this point; they are
# divided by the relevant number of days further down, after the join to
# trans_periods.
#
# The "last N days" flags are computed per record in mutate() first, so that
# summarise() returns exactly one row per period. Previously the flags were
# created inside summarise(), which produced one row per log record
# (deprecated since dplyr 1.1) and had to be collapsed with select() +
# distinct(). The result is also no longer left grouped.
ulogs_last <- ulogs_periods %>%
  mutate(days_to_end = as.numeric(end_date - date, "days"),
         last30d = days_to_end < 30,
         last14d = days_to_end < 14,
         last7d = days_to_end < 7) %>%
  group_by(subscr_id) %>%
  summarise(
    # whole period
    days_active = n(),
    mean_num_unq = sum(num_unq),
    mean_num_tot = sum(num_tot),
    mean_av_pct_played = sum(av_pct_played),
    mean_total_secs_imp = sum(total_secs_imp),
    mean_total_hrs_imp = sum(total_hrs_imp),
    mean_av_secs_played = sum(av_secs_played),
    mean_av_mins_played = sum(av_mins_played),
    mean_sh_25 = sum(sh_25),
    mean_sh_985 = sum(sh_985),
    mean_sh_100 = sum(sh_100),
    # last 30 days
    days_active30 = sum(last30d),
    mean_num_unq30 = sum(num_unq[last30d]),
    mean_num_tot30 = sum(num_tot[last30d]),
    mean_av_pct_played30 = sum(av_pct_played[last30d]),
    mean_total_secs_imp30 = sum(total_secs_imp[last30d]),
    mean_total_hrs_imp30 = sum(total_hrs_imp[last30d]),
    mean_av_secs_played30 = sum(av_secs_played[last30d]),
    mean_av_mins_played30 = sum(av_mins_played[last30d]),
    mean_sh_2530 = sum(sh_25[last30d]),
    mean_sh_98530 = sum(sh_985[last30d]),
    mean_sh_10030 = sum(sh_100[last30d]),
    # last 14 days
    days_active14 = sum(last14d),
    mean_num_unq14 = sum(num_unq[last14d]),
    mean_num_tot14 = sum(num_tot[last14d]),
    mean_av_pct_played14 = sum(av_pct_played[last14d]),
    mean_total_secs_imp14 = sum(total_secs_imp[last14d]),
    mean_total_hrs_imp14 = sum(total_hrs_imp[last14d]),
    mean_av_secs_played14 = sum(av_secs_played[last14d]),
    mean_av_mins_played14 = sum(av_mins_played[last14d]),
    mean_sh_2514 = sum(sh_25[last14d]),
    mean_sh_98514 = sum(sh_985[last14d]),
    mean_sh_10014 = sum(sh_100[last14d]),
    # last 7 days
    days_active7 = sum(last7d),
    mean_num_unq7 = sum(num_unq[last7d]),
    mean_num_tot7 = sum(num_tot[last7d]),
    mean_av_pct_played7 = sum(av_pct_played[last7d]),
    mean_total_secs_imp7 = sum(total_secs_imp[last7d]),
    mean_total_hrs_imp7 = sum(total_hrs_imp[last7d]),
    mean_av_secs_played7 = sum(av_secs_played[last7d]),
    mean_av_mins_played7 = sum(av_mins_played[last7d]),
    mean_sh_257 = sum(sh_25[last7d]),
    mean_sh_9857 = sum(sh_985[last7d]),
    mean_sh_1007 = sum(sh_100[last7d])
  ) %>%
  ungroup()

# Attach the per-period activity summary; periods without any logs get NA in
# all of the added columns.
trans_periods <- left_join(trans_periods, ulogs_last, by = "subscr_id")

trans_periods <- trans_periods %>%
  mutate(days_active = replace_na(days_active, 0),
         pct_active = replace_na(days_active / time * 100, 0),
         tot_num_tot = replace_na(mean_num_tot,0),
         tot_total_secs_imp = replace_na(mean_total_secs_imp,0),
         mean_sh_25_act = replace_na(mean_sh_25/days_active, 0),
         mean_sh_25 = replace_na(mean_sh_25/time, 0),
         mean_sh_985_act = replace_na(mean_sh_985/days_active, 0),
         mean_sh_985 = replace_na(mean_sh_985/time, 0),
         mean_sh_100_act = replace_na(mean_sh_100/days_active, 0),
         mean_sh_100 = replace_na(mean_sh_100/time, 0),
         mean_num_unq_act = replace_na(mean_num_unq/days_active, 0),
         mean_num_unq = replace_na(mean_num_unq/time, 0),
         mean_num_tot_act = replace_na(mean_num_tot/days_active, 0),
         mean_num_tot = replace_na(mean_num_tot/time, 0),
         mean_av_pct_played_act = replace_na(mean_av_pct_played/days_active, 0),
         mean_av_pct_played = replace_na(mean_av_pct_played/time, 0),
         mean_total_secs_imp_act = replace_na(mean_total_secs_imp/days_active, 0),
         mean_total_secs_imp = replace_na(mean_total_secs_imp/time, 0),
         mean_total_hrs_imp_act = replace_na(mean_total_hrs_imp/days_active, 0),
         mean_total_hrs_imp = replace_na(mean_total_hrs_imp/time, 0),
         mean_av_secs_played_act = replace_na(mean_av_secs_played/days_active, 0),
         mean_av_secs_played = replace_na(mean_av_secs_played/time, 0),
         mean_av_mins_played_act = replace_na(mean_av_mins_played/days_active, 0),
         mean_av_mins_played = replace_na(mean_av_mins_played/time, 0),
         days_active30 = replace_na(days_active30, 0),
         pct_active30 = replace_na(days_active30 / 30 * 100, 0),
         tot_num_tot30 = replace_na(mean_num_tot30,0),
         tot_total_secs_imp30 = replace_na(mean_total_secs_imp30, 0),
         mean_sh_25_act30 = replace_na(mean_sh_2530/days_active, 0),
         mean_sh_2530 = replace_na(mean_sh_2530/time, 0),
         mean_sh_985_act30 = replace_na(mean_sh_98530/days_active, 0),
         mean_sh_98530 = replace_na(mean_sh_98530/time, 0),
         mean_sh_100_act30 = replace_na(mean_sh_10030/days_active, 0),
         mean_sh_10030 = replace_na(mean_sh_10030/time, 0),
         mean_num_unq_act30 = replace_na(mean_num_unq30/days_active30, 0),
         mean_num_unq30 = replace_na(mean_num_unq30/30, 0),
         mean_num_tot_act30 = replace_na(mean_num_tot30/days_active30, 0),
         mean_num_tot30 = replace_na(mean_num_tot30/30, 0),
         mean_av_pct_played_act30 = replace_na(mean_av_pct_played30/days_active30, 0),
         mean_av_pct_played30 = replace_na(mean_av_pct_played30/30, 0),
         mean_total_secs_imp_act30 = replace_na(mean_total_secs_imp30/days_active30, 0),
         mean_total_secs_imp30 = replace_na(mean_total_secs_imp30/30, 0),
         mean_total_hrs_imp_act30 = replace_na(mean_total_hrs_imp30/days_active30, 0),
         mean_total_hrs_imp30 = replace_na(mean_total_hrs_imp30/30, 0),
         mean_av_secs_played_act30 = replace_na(mean_av_secs_played30/days_active30, 0),
         mean_av_secs_played30 = replace_na(mean_av_secs_played30/30, 0),
         mean_av_mins_played_act30 = replace_na(mean_av_mins_played30/days_active30, 0),
         mean_av_mins_played30 = replace_na(mean_av_mins_played30/30, 0),
         days_active14 = replace_na(days_active14, 0),
         pct_active14 = replace_na(days_active14 / 14 * 100, 0),
         tot_num_tot14 = replace_na(mean_num_tot14,0),
         tot_total_secs_imp14 = replace_na(mean_total_secs_imp14, 0),
         mean_sh_25_act14 = replace_na(mean_sh_2514/days_active, 0),
         mean_sh_2514 = replace_na(mean_sh_2514/time, 0),
         mean_sh_985_act14 = replace_na(mean_sh_98514/days_active, 0),
         mean_sh_98514 = replace_na(mean_sh_98514/time, 0),
         mean_sh_100_act14 = replace_na(mean_sh_10014/days_active, 0),
         mean_sh_10014 = replace_na(mean_sh_10014/time, 0),
         mean_num_unq_act14 = replace_na(mean_num_unq14/days_active14, 0),
         mean_num_unq14 = replace_na(mean_num_unq14/14, 0),
         mean_num_tot_act14 = replace_na(mean_num_tot14/days_active14, 0),
         mean_num_tot14 = replace_na(mean_num_tot14/14, 0),
         mean_av_pct_played_act14 = replace_na(mean_av_pct_played14/days_active14, 0),
         mean_av_pct_played14 = replace_na(mean_av_pct_played14/14, 0),
         mean_total_secs_imp_act14 = replace_na(mean_total_secs_imp14/days_active14, 0),
         mean_total_secs_imp14 = replace_na(mean_total_secs_imp14/14, 0),
         mean_total_hrs_imp_act14 = replace_na(mean_total_hrs_imp14/days_active14, 0),
         mean_total_hrs_imp14 = replace_na(mean_total_hrs_imp14/14, 0),
         mean_av_secs_played_act14 = replace_na(mean_av_secs_played14/days_active14, 0),
         mean_av_secs_played14 = replace_na(mean_av_secs_played14/14, 0),
         mean_av_mins_played_act14 = replace_na(mean_av_mins_played14/days_active14, 0),
         mean_av_mins_played14 = replace_na(mean_av_mins_played14/14, 0),
         days_active7 = replace_na(days_active7, 0),
         pct_active7 = replace_na(days_active7 / 7 * 100, 0),
         tot_num_tot7 = replace_na(mean_num_tot7,0),
         tot_total_secs_imp7 = replace_na(mean_total_secs_imp7, 0),
         mean_sh_25_act7 = replace_na(mean_sh_257/days_active, 0),
         mean_sh_257 = replace_na(mean_sh_257/time, 0),
         mean_sh_985_act7 = replace_na(mean_sh_9857/days_active, 0),
         mean_sh_9857 = replace_na(mean_sh_9857/time, 0),
         mean_sh_100_act7 = replace_na(mean_sh_1007/days_active, 0),
         mean_sh_1007 = replace_na(mean_sh_1007/time, 0),
         mean_num_unq_act7 = replace_na(mean_num_unq7/days_active7, 0),
         mean_num_unq7 = replace_na(mean_num_unq7/7, 0),
         mean_num_tot_act7 = replace_na(mean_num_tot7/days_active7, 0),
         mean_num_tot7 = replace_na(mean_num_tot7/7, 0),
         mean_av_pct_played_act7 = replace_na(mean_av_pct_played7/days_active7, 0),
         mean_av_pct_played7 = replace_na(mean_av_pct_played7/7, 0),
         mean_total_secs_imp_act7 = replace_na(mean_total_secs_imp7/days_active7, 0),
         mean_total_secs_imp7 = replace_na(mean_total_secs_imp7/7, 0),
         mean_total_hrs_imp_act7 = replace_na(mean_total_hrs_imp7/days_active7, 0),
         mean_total_hrs_imp7 = replace_na(mean_total_hrs_imp7/7, 0),
         mean_av_secs_played_act7 = replace_na(mean_av_secs_played7/days_active7, 0),
         mean_av_secs_played7 = replace_na(mean_av_secs_played7/7, 0),
         mean_av_mins_played_act7 = replace_na(mean_av_mins_played7/days_active7, 0),
         mean_av_mins_played7 = replace_na(mean_av_mins_played7/7, 0))

# History features for users with more than one subscription period.
# Rows are ordered by msno_subscr_id within each user, so that:
#   * lag_*  columns describe the immediately preceding period,
#   * prev_* columns aggregate over all preceding periods.
# For a user's first period every lag_*/prev_* numeric value is 0.
trans_lags <- trans_periods %>%
  arrange(msno, msno_subscr_id) %>%
  group_by(msno) %>%
  filter(n() > 1) %>%
  transmute(subscr_id = unique(subscr_id),
            
            lag_time = replace_na(lag(time), 0),
            prev_time = cumsum(replace_na(lag(time), 0)),
            
            lag_days_active = replace_na(lag(days_active), 0),
            prev_days_active = cumsum(replace_na(lag(days_active), 0)),
            
            # lag_pct_active is carried over as stored in pct_active, whereas
            # prev_pct_active is a plain ratio of days (no * 100 here).
            lag_pct_active = replace_na(lag(pct_active), 0),
            prev_pct_active = replace_na(prev_days_active/prev_time, 0),
            
            lag_tot_num_tot = replace_na(lag(tot_num_tot), 0),
            prev_tot_num_tot = cumsum(replace_na(lag(tot_num_tot), 0)),
            
            lag_tot_total_secs_imp = replace_na(lag(tot_total_secs_imp), 0),
            prev_tot_total_secs_imp = cumsum(replace_na(lag(tot_total_secs_imp), 0)),
            
            # x / 0 yields NaN for the first period; replace_na() maps it to 0.
            lag_mean_num_tot = replace_na(lag_tot_num_tot/lag_time, 0),
            prev_mean_num_tot = replace_na(prev_tot_num_tot/prev_time, 0),
            
            lag_mean_num_tot_act = replace_na(lag_tot_num_tot/lag_days_active, 0),
            prev_mean_num_tot_act = replace_na(prev_tot_num_tot/prev_days_active, 0),
            
            lag_mean_total_secs_imp = replace_na(lag_tot_total_secs_imp/lag_time, 0),
            prev_mean_total_secs_imp = replace_na(prev_tot_total_secs_imp/prev_time, 0),
            
            lag_mean_total_secs_imp_act = replace_na(lag_tot_total_secs_imp/lag_days_active, 0),
            prev_mean_total_secs_imp_act = replace_na(prev_tot_total_secs_imp/prev_days_active, 0),
            
            # Running means over the preceding periods: the divisor is the
            # number of periods seen so far (row_number() - 1), not the user's
            # total number of periods, which would use future information.
            # The first period gives 0 / 0 = NaN and is set to 0.
            lag_mean_av_pct_played = replace_na(lag(mean_av_pct_played), 0),
            prev_mean_av_pct_played = replace_na(cumsum(replace_na(lag(mean_av_pct_played), 0)) /
                                                   (row_number() - 1), 0),
            
            lag_mean_av_pct_played_act = replace_na(lag(mean_av_pct_played_act), 0),
            prev_mean_av_pct_played_act = replace_na(cumsum(replace_na(lag(mean_av_pct_played_act), 0)) /
                                                       (row_number() - 1), 0),
            
            lag_mean_num_unq = replace_na(lag(mean_num_unq), 0),
            prev_mean_num_unq = replace_na(cumsum(replace_na(lag(mean_num_unq), 0)) /
                                             (row_number() - 1), 0),
            
            lag_mean_num_unq_act = replace_na(lag(mean_num_unq_act), 0),
            prev_mean_num_unq_act = replace_na(cumsum(replace_na(lag(mean_num_unq_act), 0)) /
                                                 (row_number() - 1), 0),
            
            # Payment method of the preceding period; level 0 = no preceding period.
            lag_mode_payment_id = factor(replace_na(lag(as.numeric(as.character(mode_payment_method_id))), 0),
                                         c(0, as.numeric(levels(trans_df$payment_method_id)))),
            # NOTE(review): `mode()` must resolve to a user-defined statistical
            # mode helper here; base::mode() returns the storage mode as a
            # string. Confirm it is defined earlier in the file.
            # NOTE(review): the mode is taken over all of the user's periods
            # except the last one for every row, not only over the periods
            # preceding the current row. Confirm this is intended.
            prev_mode_payment_id = factor(replace_na(mode(as.numeric(as.character(mode_payment_method_id[-n()]))), 0),
                                          c(0, as.numeric(levels(trans_df$payment_method_id)), 999)),
            
            lag_sum_actual_amount_paid = replace_na(lag(sum_actual_amount_paid), 0),
            prev_sum_actual_amount_paid = cumsum(replace_na(lag(sum_actual_amount_paid), 0)),
            
            lag_mean_actual_amount_paid = replace_na(lag_sum_actual_amount_paid/lag_time, 0),
            prev_mean_actual_amount_paid = replace_na(prev_sum_actual_amount_paid/prev_time, 0),
            
            lag_num_cancel = replace_na(lag(n_cancel), 0),
            prev_num_cancel = cumsum(replace_na(lag(n_cancel), 0))) %>%
  ungroup()

# Attach the history features to every period. Users with a single period
# have no row in trans_lags, so their lag_*/prev_* values come back as NA
# from the join and are filled with 0.
trans_periods <- left_join(trans_periods, trans_lags, by = c("msno", "subscr_id")) %>%
  mutate(across(c(contains("lag_"), dplyr::starts_with("prev_")),
                ~ replace_na(.x, 0)))

# Collapse payment-method levels seen in fewer than 30 periods into "999".
# Each factor is lumped from its own column; prev_mode_payment_id was
# previously derived from lag_mode_payment_id, which overwrote the
# "all previous subscriptions" feature with the lagged one.
trans_periods <- trans_periods %>%
  mutate(lag_mode_payment_id = fct_lump_min(lag_mode_payment_id, min = 30, other_level = "999"),
         prev_mode_payment_id = fct_lump_min(prev_mode_payment_id, min = 30, other_level = "999"))
# Human-readable variable labels picked up by the summary table below.
trans_periods <- trans_periods %>%
  set_variable_labels(
    time = "Subscription duration, days",
    pct_active = "%of active days in subscription",
    pct_active30 = "%of active days in the last 30 days of subscription",
    pct_active14 = "%of active days in the last 14 days of subscription",
    pct_active7 = "%of active days in the last 7 days of subscription"
  )

# Descriptive table of subscription duration and activity shares by churn
# status. contains("pct_active") selects every column whose name includes
# that string, not only the four labelled ones.
trans_periods %>%
  select(churn, time, contains("pct_active")) %>%
  mutate(churn = factor(churn, labels = c("Churn = 0", "Churn = 1"))) %>%
  as.data.frame() %>%
  tbl_summary(
    by = "churn",
    # pct_active7 is forced to the multi-row continuous layout explicitly.
    type = list(all_continuous() ~ "continuous2",
                pct_active7 ~ "continuous2"),
    statistic = list(all_continuous() ~ c("{mean} ({sd})",
                                          "{median} ({p25}-{p75})",
                                          "{min}-{max}")),
    # One digits entry per statistic, in order:
    # mean, sd, median, p25, p75, min, max.
    digits = list(all_continuous() ~ c(1, 1, 1, 1, 1, 0, 0),
                  time ~ c(1, 1, 0, 0, 0, 0, 0))
  ) %>%
  modify_header(all_stat_cols() ~ "**{level}**<br>N = {n}") %>%
  add_overall(col_label = "**Total**<br>N = {N}", last = TRUE) %>%
  modify_footnote(update = everything() ~ NA) %>%
  bold_labels()
Characteristic Churn = 0
N = 10750
Churn = 1
N = 14603
Total
N = 25353
Subscription duration, days
Mean (SD) 431.3 (262.1) 130.0 (168.4) 257.8 (260.1)
Median (25%-75%) 416 (226-643) 31 (11-196) 172 (31-423)
Minimum-Maximum 1-821 1-787 1-821
%of active days in subscription
Mean (SD) 55.4 (29.3) 45.3 (32.2) 49.6 (31.4)
Median (25%-75%) 59.7 (32.0-80.6) 45.2 (14.3-74.1) 51.6 (21.2-77.1)
Minimum-Maximum 0-100 0-100 0-100
%of active days in the last 30 days of subscription
Mean (SD) 54.4 (34.0) 37.5 (34.2) 44.7 (35.2)
Median (25%-75%) 60.0 (23.3-86.7) 26.7 (3.3-70.0) 43.3 (10.0-76.7)
Minimum-Maximum 0-100 0-100 0-100
%of active days in the last 14 days of subscription
Mean (SD) 55.1 (35.4) 38.1 (34.6) 45.3 (35.9)
Median (25%-75%) 64.3 (21.4-85.7) 28.6 (7.1-71.4) 42.9 (7.1-78.6)
Minimum-Maximum 0-100 0-100 0-100
%of active days in the last 7 days of subscription
Mean (SD) 54.7 (36.9) 38.6 (36.7) 45.5 (37.6)
Median (25%-75%) 57.1 (14.3-85.7) 28.6 (0.0-71.4) 42.9 (0.0-85.7)
Minimum-Maximum 0-100 0-100 0-100
lag_pct_active
Mean (SD) 15.9 (29.8) 11.1 (25.7) 13.2 (27.6)
Median (25%-75%) 0.0 (0.0-13.8) 0.0 (0.0-0.0) 0.0 (0.0-0.0)
Minimum-Maximum 0-100 0-100 0-100
prev_pct_active
Mean (SD) 0.2 (0.3) 0.1 (0.3) 0.1 (0.3)
Median (25%-75%) 0.0 (0.0-0.1) 0.0 (0.0-0.0) 0.0 (0.0-0.0)
Minimum-Maximum 0-1 0-1 0-1


# Density of one activity-share column, filled by churn status.
#   data       - data frame with a `churn` column and the column named in x_var
#   x_var      - name (string) of the column shown on the x axis
#   plot_title - panel title
plot_pct_active_density <- function(data, x_var, plot_title) {
  data %>%
    mutate(churn = factor(churn, labels = c("Churn = 0", "Churn = 1"))) %>%
    ggplot() +
    aes(x = .data[[x_var]], fill = churn) +
    geom_density(alpha = 0.4, color = "white") +
    labs(x = "", y = "Density", title = plot_title, fill = "") +
    scale_x_continuous(breaks = seq(0, 100, 10)) +
    scale_y_continuous(expand = c(0, 0)) +
    ggthemes::theme_tufte() +
    theme(legend.position = c(0.5, 0.9),
          plot.title = element_text(face = "bold"),
          axis.text.x = element_text(size = 10))
}

p1 <- plot_pct_active_density(trans_periods, "pct_active",
                              "% of active days in subscription")

p2 <- plot_pct_active_density(trans_periods, "pct_active30",
                              "% of active days in the last 30 days of subscr.")

p3 <- plot_pct_active_density(trans_periods, "pct_active14",
                              "% of active days in the last 14 days of subscr.")

p4 <- plot_pct_active_density(trans_periods, "pct_active7",
                              "% of active days in the last 7 days of subscr.")

# 2 x 2 grid sharing a single legend.
ggarrange(p1, p2, p3, p4, nrow = 2, ncol = 2, common.legend = TRUE)
# Data with churn recoded to a labelled factor, shared by all four panels.
pct_density_data <- trans_periods %>%
  mutate(churn = factor(churn, labels = c("Churn = 0", "Churn = 1")))

# Layers, scales and theme common to every activity-share density panel.
pct_density_layers <- list(
  geom_density(alpha = 0.4, color = "white"),
  scale_x_continuous(breaks = seq(0, 100, 10)),
  scale_y_continuous(expand = c(0, 0)),
  ggthemes::theme_tufte(),
  theme(legend.position = c(0.5, 0.9),
        plot.title = element_text(face = "bold"),
        axis.text.x = element_text(size = 10))
)

p1 <- ggplot(pct_density_data, aes(x = pct_active, fill = churn)) +
  pct_density_layers +
  labs(x = "", y = "Density", title = "% of active days in subscription", fill = "")

p2 <- ggplot(pct_density_data, aes(x = pct_active30, fill = churn)) +
  pct_density_layers +
  labs(x = "", y = "Density", title = "% of active days in the last 30 days of subscr.", fill = "")

p3 <- ggplot(pct_density_data, aes(x = pct_active14, fill = churn)) +
  pct_density_layers +
  labs(x = "", y = "Density", title = "% of active days in the last 14 days of subscr.", fill = "")

p4 <- ggplot(pct_density_data, aes(x = pct_active7, fill = churn)) +
  pct_density_layers +
  labs(x = "", y = "Density", title = "% of active days in the last 7 days of subscr.", fill = "")

# 2 x 2 grid; each panel keeps its own legend.
ggarrange(p1, p2, p3, p4, nrow = 2, ncol = 2)

– Subscriptions that ended in churn are somewhat less active, both overall and in the last days of the subscription.


The problem with all the user-log characteristics described above is that they are time-dependent, i.e. they change over the course of a subscription. As a consequence, we cannot use them in basic statistical or ML prediction models without special options and preliminary data restructuring. We also cannot use indicators averaged over the whole subscription period, because for validation and test sets, as well as for any new data, we do not know the (future) length of the subscription or the (future) behaviour of the user, while we would like to make predictions at any point of a subscription (either at its start or while it is in progress) using all information available prior to that moment.

In addition, we cannot plot Kaplan-Meier curves for time-dependent predictors - instead we should use Simon and Makuch plots. However, like Kaplan-Meier curves, they are applicable only to categorical variables.


4 Predicting survival distributions for subscriptions


We split our data on subscription periods into train and test sets, with 70% of all periods in the train set.


4.1 Features


Here we start from baseline models without time-dependent covariates and will use only interpretable models or, at least, those which allow estimation of variable importance, so a company will be able to understand what could be a cause for user’s attrition.

From all available features for each subscription period we chose the following:

  • churned_before_num - number of previous churns (set to 0 for the first subscription period of the corresponding user),

  • cityG - all unique numeric codes, with 19 and 20 codes grouped into one category,

  • bdG - age split into 2 categories: less than 10 and 10+ (years),

  • gender - female, male, unknown,

  • registered_via - all unique numeric codes without any grouping,

  • month_period - month at which subscription was started,

  • time_from_reg - time from registration at the service till the start of subscription (days),

  • first_payment_method_id - payment method used in the first transaction (at the start of subscription period),

  • first_plan_list_price - plan list price stated in the first transaction (at the start of subscription period),

  • time_from_prev_exp - time from the expiration date of the previous subscription (days, set to 0 for the first subscription of the corresponding user),

  • prev_time - sum of all previous subscription durations (days, set to 0 for the first subscription of the corresponding user),

  • prev_pct_active - share of active days to the total length of all previous subscriptions (in %, set to 0 for the first subscription of the corresponding user),

  • prev_mean_num_tot - average number of songs listened per day for all previous subscriptions (set to 0 for the first subscription of the corresponding user),

  • prev_mean_total_secs_imp - average number of seconds listened per day for all previous subscriptions (set to 0 for the first subscription of the corresponding user; we used values of daily seconds with imputed large values as it was described above),

  • prev_mean_num_unq - average number of unique songs listened per day for all previous subscriptions (set to 0 for the first subscription of the corresponding user),

  • prev_mean_av_pct_played_act - average % of one song’s length played per active day for all previous subscriptions (set to 0 for the first subscription of the corresponding user),

  • prev_mean_actual_amount_paid - mean (daily) actual amount paid for all previous subscriptions (set to 0 for the first subscription of the corresponding user),

  • prev_num_cancel - total number of cancellations for all previous subscriptions (set to 0 for the first subscription of the corresponding user; remember that cancellations include changes in subscription plans as well),

  • prev_mode_payment_id - modal value for payment method for all previous subscriptions (first subscription of the corresponding user was placed in the separate category of this feature; all categories with the number of subscriptions less than 30 were combined into one group with label “999”).

We also tried models with lag values instead of values for all previous subscriptions - on average, they performed worse, so we stuck to data on all previous subscriptions.


For models with time-dependent covariates we first restructured our data set. Each subscription period was split into several intervals by the dates of user logs or transactions inside it (if there were any; otherwise the period remained unchanged). For each interval we were able to use all information about users, their transactions and user log statistics, but for comparability with models with time-independent covariates we chose the following features (a feature is time-dependent unless it is explicitly marked as time-independent):

  • churned_before_num (time-independent) - number of previous churns (set to 0 for the first subscription period of the corresponding user),

  • cityG (time-independent) - numeric code for city, with 19 and 20 codes grouped into one category,

  • bdG (time-independent) - age split into 2 categories: less than 10 and 10+ (years),

  • gender (time-independent) - female, male, unknown,

  • registered_via (time-independent) - all unique numeric codes without any grouping,

  • month_period (time-independent) - month at which subscription was started,

  • time_from_reg (time-independent) - time from registration at the service till the start of subscription (days),

  • time_from_prev_exp (time-independent) - time from the expiration date of the previous subscription (days, set to 0 for the first subscription of the corresponding user),

  • prev_time (time-independent) - sum of all previous subscription durations (days, set to 0 for the first subscription of the corresponding user),

  • pct_active - share of active (with logs) days to the total length of the interval (in %),

  • mean_num_tot - average number of songs listened per day during the interval,

  • mean_total_secs_imp - average number of seconds listened per day during the interval,

  • mean_num_unq - average number of unique songs listened per day during the interval,

  • mean_av_pct_played_act - average % of one song’s length played per active day during the interval,

  • cum_mean_actual_amount_paid - cumulative mean (daily) actual amount paid from the start of subscription to the end of this interval,

  • cum_num_cancel - cumulative number of cancellations from the start of subscription to the end of this interval,

  • last_payment_method_id - last value for payment method from the start of subscription to the end of this interval,

  • last_plan_list_price - last value for plan list price from the start of subscription to the end of this interval.


4.2 Models


I have chosen Cox PH and RSF as models for estimation, because:

  • they allow to predict survival functions,

  • there are packages in R which allow estimation and prediction of these models with time-dependent covariates,

  • in preliminary analysis they outperformed all other models (parametric PH, AFT, xgboost) for time-independent covariates,

  • they are quite different in their assumptions - first of all, RSF does not assume proportionality in hazards, while, for example, xgboost does - thus, we should obtain predictions based on two different sets of assumptions.

Maximum depth and minimum leaf size for RSF were tuned using grid search with holdout, i.e. the train sample was split into train and validation sets in proportion 8:2, and performance with different sets of parameters was estimated using Harrell’s C index on the validation set. 500 trees were estimated (100 trees gave almost the same result).

For overall performance estimation and model comparison (for models with time-independent covariates only!) I used the following metrics:

  • Harrell’s C index,

  • Uno’s C index (in comparison with Harrell’s one it incorporates weights from estimation of censoring distribution),

  • integrated Brier score (IBS) for the period from 1 to 720 days (~ 2 years),

  • time-dependent Brier scores for several time points: 7, 14, 30, 60, 90, 180, 360 and 720 days to estimate how prediction quality changes with time.

For models with time-dependent covariates I used Harrell’s C-index estimated both on train and test sets and IBS for the period from 1 to 720 days estimated on test only.


4.3 Instruments


I’ve used a mix of Python and R for all the project.

From R I used the following packages:

  • survival (Terry M. Therneau 2021) - for Kaplan-Meier estimators (above), Cox regression with time-dependent covariates, predicting survival probabilities and estimating Harrell’s C after that,

  • survminer (Kassambara et al. 2021) - for plotting Kaplan-Meier curves (above),

  • bshazard (Rebora, Salim, and Reilly 2014) - for non-parametric estimation and plotting hazard functions (above),

  • rms (Harrell 2021) - for estimating variable importance in Cox regressions by proportion of explainable log-likelihood,

  • LTRCforest (Yao et al. 2021) - for estimating IBS in models with time-dependent covariates and for RSF with time-dependent covariates.


From Python I used the following libraries:

  • scikit-learn(Pedregosa et al. 2011) - for train/test split, grid search CV (tuning hyperparameters for RSF with time-independent covariates),

  • scikit-survival (Pölsterl 2020) - for preprocessing (one-hot encoding), estimating Cox PH and RSF models with time-independent covariates, predicting survival after them and estimates of all evaluation metrics.


4.4 Results


4.4.1 Models with time-independent covariates


# Feature matrix for the models with time-independent covariates.
X <- trans_periods %>%
  transmute(
    churned_before_num,
    # cities with fewer than 30 periods are pooled under the "19-20" level
    cityG = fct_lump_min(city, 30, other_level = "19-20"),
    bdG,
    gender,
    registered_via,
    month_period,
    time_from_reg,
    time_from_prev_exp,
    first_payment_method_id,
    first_plan_list_price,
    prev_time,
    # rescaled from a 0-1 ratio to percent
    prev_pct_active = 100*prev_pct_active,
    prev_mean_num_tot,
    prev_mean_total_secs_imp,
    prev_mean_num_unq,
    prev_mean_av_pct_played_act,
    prev_mean_actual_amount_paid,
    prev_num_cancel,
    prev_mode_payment_id
  )

# Baseline features: every column whose name contains "prev_" is dropped,
# which also removes time_from_prev_exp.
X_base <- select(X, -contains("prev_"))

# Event indicator and duration, in that column order.
Y_clmns <- select(trans_periods, churn, time)

sample_X_test_id <- read.csv("sample_X_test_id.csv")

# Names of the non-numeric and numeric columns of the full feature set.
cat_columns <- names(X)[!map_lgl(X, is.numeric)]
num_columns <- names(X)[map_lgl(X, is.numeric)]

# The same split for the baseline feature set.
cat_base_columns <- names(X_base)[!map_lgl(X_base, is.numeric)]
num_base_columns <- names(X_base)[map_lgl(X_base, is.numeric)]
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.inspection import permutation_importance
from sklearn.base import TransformerMixin
from sklearn.model_selection import train_test_split, GridSearchCV, KFold
from sksurv.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import concordance_index_censored, concordance_index_ipcw, brier_score, integrated_brier_score, as_concordance_index_ipcw_scorer

# Function to get predictions from estimator

def get_predictions(Xtrain, Xtest, ytrain, ytest, estimator, rf=False):
  """Return risk scores and survival functions for the train and test sets.

  Parameters:
    Xtrain, Xtest: feature sets passed to the estimator's prediction methods.
    ytrain, ytest: accepted for a signature parallel to get_metrics_df; unused.
    estimator: fitted object exposing predict() and
      predict_survival_function().
    rf: accepted so that call sites passing rf=True keep working; it has
      no effect on the result.

  Returns:
    Tuple (pred_train, pred_test, pred_surv_train, pred_surv_test).
  """
  pred_train = estimator.predict(Xtrain)
  pred_test = estimator.predict(Xtest)
  pred_surv_train = estimator.predict_survival_function(Xtrain)
  pred_surv_test = estimator.predict_survival_function(Xtest)
  
  return pred_train, pred_test, pred_surv_train, pred_surv_test

# Function to get metrics for model

def get_metrics_df(Xtrain, Xtest, ytrain, ytest, estimator):
  """Evaluate a fitted survival estimator on the train and test sets.

  Parameters:
    Xtrain, Xtest: feature sets passed to the estimator's prediction methods.
    ytrain, ytest: structured arrays with "Churn" (event) and "Time" fields.
    estimator: fitted object exposing predict() and
      predict_survival_function().

  Returns:
    Tuple (results, bs_train, bs_test):
      results: DataFrame with columns Metric / Train / Test, holding
        Harrell's C, Uno's C, the integrated Brier score (test only; the
        Train cell holds the string 'NULL') and Brier scores at fixed days.
      bs_train, bs_test: raw brier_score() outputs; element [1] holds the
        scores used in the table.
  """
  pred_train = estimator.predict(Xtrain)
  pred_test = estimator.predict(Xtest)
  pred_surv_train = estimator.predict_survival_function(Xtrain)
  pred_surv_test = estimator.predict_survival_function(Xtest)

  hc_train = concordance_index_censored(ytrain["Churn"], ytrain["Time"], pred_train)[0]
  hc_test = concordance_index_censored(ytest["Churn"], ytest["Time"], pred_test)[0]
  uc_train = concordance_index_ipcw(ytrain, ytrain, pred_train)[0]
  uc_test = concordance_index_ipcw(ytest, ytest, pred_test)[0]
  
  # Survival step functions accept an array of time points, so each one is
  # evaluated once on the whole grid instead of once per time point.
  times = np.array([7, 14, 30, 60, 90, 180, 360, 720])
  preds_train = np.asarray([fn(times) for fn in pred_surv_train])
  preds_test = np.asarray([fn(times) for fn in pred_surv_test])

  bs_train = brier_score(ytrain, ytrain, preds_train, times)
  bs_test = brier_score(ytest, ytest, preds_test, times)
  
  # Daily grid for the integrated Brier score; np.arange(1, 720) covers
  # days 1..719. The train-set IBS is intentionally not computed.
  times2 = np.arange(1, 720)
  preds_test2 = np.asarray([fn(times2) for fn in pred_surv_test])
  
  ibs_test = integrated_brier_score(ytest, ytest, preds_test2, times2)
  
  results = pd.DataFrame(
    {"Metric": ['Harrel`s C', 'Uno`s C', 'IBS 1-720d', 'BS 7d', 'BS 14d', 'BS 30d', 'BS 60d',
      'BS 90d', 'BS 180d', 'BS 360d', 'BS 720d'],
      "Train": [hc_train, uc_train, 'NULL'] + [i for i in bs_train[1]],
      "Test": [hc_test, uc_test, ibs_test] + [i for i in bs_test[1]]})
  
  return results, bs_train, bs_test

# Function to plot time-dependent Brier score

def plot_td_brier(bs_tuple):
  """Plot time-dependent Brier scores against days from subscription start.

  Parameters:
    bs_tuple: pair whose element [0] holds the time points and element [1]
      the matching Brier scores (as returned by get_metrics_df).
  """
  plt.figure()
  # Plot the tuple that was passed in rather than a fixed global, so every
  # model's own scores are shown.
  plt.plot(list(bs_tuple[0]), list(bs_tuple[1]), marker="o")
  plt.xlabel("Days from start of subscription")
  plt.ylabel("Time-dependent Brier score")
  plt.xticks([7,30,60,90,180,360,720], [7,30,60,90,180,360,720])
  plt.grid()
  plt.show()
# Pull the objects prepared in the R session into Python.
X_base = r.X_base
X_base.shape

X = r.X
X.shape

cat_columns = r.cat_columns
num_columns = r.num_columns
cat_base_columns = r.cat_base_columns
num_base_columns = r.num_base_columns

# Structured outcome array: boolean event flag plus float duration.
aux = [(bool(row[0]), row[1]) for row in r.Y_clmns.to_numpy()]
y = np.array(aux, dtype=[('Churn', '?'), ('Time', '<f8')])
y
y.shape

random_state = 42

# Same test_size and seed for both feature sets, so one pair of outcome
# splits is reused for the baseline features.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=random_state)
X_base_train, X_base_test, _, _ = train_test_split(X_base, y, test_size=0.3, random_state=random_state)

# Fixed sample of test rows used for the example survival curves.
sample_X_test_id = r.sample_X_test_id
sample_ids = sample_X_test_id['ID'].values
sample_X_test = X_test.loc[sample_ids]
sample_X_base_test = X_base_test.loc[sample_ids]

sample_y_churn = [int(rec[0]) for rec in y[sample_ids]]
sample_y_time = [int(rec[1]) for rec in y[sample_ids]]

# Order the sample by observed time, longest first.
sample_sort = np.argsort(sample_y_time)[::-1]

sample_X_test = sample_X_test.iloc[sample_sort]
sample_X_base_test = sample_X_base_test.iloc[sample_sort]

sample_y_churn = [sample_y_churn[pos] for pos in sample_sort]
sample_y_time = [sample_y_time[pos] for pos in sample_sort]

feature_names = X_train.columns.tolist()
feature_names_base = X_base_train.columns.tolist()
# Baseline Cox without data on previous subscription

# One-hot encode the features, then fit a Cox PH model.
cox_ph_base = Pipeline(
  [
    ('encode', OneHotEncoder()),
    ("model", CoxPHSurvivalAnalysis())
  ]
)

# Stand-alone encoder used only to recover the encoded column names.
encoder_names = OneHotEncoder()

cox_ph_base.fit(X_base_train, y_train)

# Coefficient table sorted by absolute coefficient, with hazard ratios exp(coef).
cox_ph_base_coef = pd.Series(cox_ph_base.named_steps['model'].coef_, index=encoder_names.fit_transform(X_base_train).columns)
cox_ph_base_coef = cox_ph_base_coef.reset_index()
cox_ph_base_coef.columns = ['Variable', 'Coef']
cox_ph_base_coef.sort_values(by='Coef', key=abs, ascending=False, inplace=True)
cox_ph_base_coef['HR'] = cox_ph_base_coef['Coef'].apply(lambda x: np.exp(x))
cox_ph_base_coef

cox_base_pred_train, cox_base_pred_test, cox_base_surv_train, cox_base_surv_test = get_predictions(X_base_train, X_base_test, y_train, y_test, cox_ph_base)

cox_base_results, bs_base_train, bs_base_test = get_metrics_df(X_base_train, X_base_test, y_train, y_test, cox_ph_base)

cox_base_results

plot_td_brier(bs_base_test)

# Predicted survival curves for the fixed sample of test subscriptions.
pred_surv = cox_ph_base.predict_survival_function(sample_X_base_test)

time_points = np.arange(1, 720)
plt.figure()
for i, surv_func in enumerate(pred_surv):
    plt.step(time_points, surv_func(time_points), where="post",
             label="Churn={churn}, Time={time}d".format(churn=sample_y_churn[i], time=sample_y_time[i]))
# Raw string: "\h" is not a valid escape sequence in a normal string literal.
plt.ylabel(r"Est. probability of survival $\hat{S}(t)$")
plt.xlabel("Time $t$")
plt.title('Cox PH baseline')
plt.legend(bbox_to_anchor=(0.7, 0.7))
plt.show()


# Cox with data on previous subscription

# Same pipeline as the baseline, fitted on the full feature set.
cox_ph = Pipeline(
  [
    ('encode', OneHotEncoder()),
    ("model", CoxPHSurvivalAnalysis())
  ]
)

cox_ph.fit(X_train, y_train)

# Coefficient table sorted by absolute coefficient, with hazard ratios exp(coef).
# encoder_names is the stand-alone OneHotEncoder created for the baseline model.
cox_ph_coef = pd.Series(cox_ph.named_steps['model'].coef_, index=encoder_names.fit_transform(X_train).columns)
cox_ph_coef = cox_ph_coef.reset_index()
cox_ph_coef.columns = ['Variable', 'Coef']
cox_ph_coef.sort_values(by='Coef', key=abs, ascending=False, inplace=True)
cox_ph_coef['HR'] = cox_ph_coef['Coef'].apply(lambda x: np.exp(x))
cox_ph_coef

cox_pred_train, cox_pred_test, cox_surv_train, cox_surv_test = get_predictions(X_train, X_test, y_train, y_test, cox_ph)

cox_results, bs_train, bs_test = get_metrics_df(X_train, X_test, y_train, y_test, cox_ph)

cox_results

plot_td_brier(bs_test)


# Predicted survival curves for the fixed sample of test subscriptions.
pred_surv = cox_ph.predict_survival_function(sample_X_test)

time_points = np.arange(1, 720)
plt.figure()
for i, surv_func in enumerate(pred_surv):
    plt.step(time_points, surv_func(time_points), where="post",
             label="Churn={churn}, Time={time}d".format(churn=sample_y_churn[i], time=sample_y_time[i]))
# Raw string: "\h" is not a valid escape sequence in a normal string literal.
plt.ylabel(r"Est. probability of survival $\hat{S}(t)$")
plt.xlabel("Time $t$")
plt.title('Cox PH prev')
plt.legend(bbox_to_anchor=(0.7, 0.7))
plt.show()
# Train and validation sets

# Hold out 20% of the training data for hyper-parameter validation.
X_base_tt, X_base_val, y_tt, y_val = train_test_split(X_base_train, y_train, test_size=0.2, random_state=random_state)
X_tt, X_val, _, _ = train_test_split(X_train, y_train, test_size=0.2, random_state=random_state)

# RSF tuning

# Validation scores of every candidate, appended in the order they are tried
# (max_depth candidates first, then min_samples_leaf candidates).
rsf_base_scores = []

# Candidates are wrapped in the same one-hot encoding pipeline as the final
# models, so the forest receives encoded features.
for md in np.arange(5, 11, dtype=int):
  rsf = Pipeline(
    [
      ('encode', OneHotEncoder()),
      ("model", RandomSurvivalForest(n_estimators=500, max_depth=md, max_features="sqrt", min_samples_leaf=10, min_samples_split=10, n_jobs=-1, random_state=random_state))
    ]
  )
  rsf_fit = rsf.fit(X_base_tt, y_tt)
  rsf_base_scores += [rsf.score(X_base_val, y_val)]
  print(md, rsf_base_scores[-1])

for msl in np.array([5,10,15,20], dtype=int):
  rsf = Pipeline(
    [
      ('encode', OneHotEncoder()),
      ("model", RandomSurvivalForest(n_estimators=500, max_depth=14, max_features="sqrt", min_samples_leaf=msl, min_samples_split=10, n_jobs=-1, random_state=random_state))
    ]
  )
  rsf_fit = rsf.fit(X_base_tt, y_tt)
  rsf_base_scores += [rsf.score(X_base_val, y_val)]
  print(msl, rsf_base_scores[-1])
  

# RSF with baseline data

# One-hot encode the features, then fit a random survival forest.
rsf_base = Pipeline(
  [
    ('encode', OneHotEncoder()),
    ("model", RandomSurvivalForest(n_estimators=500, max_depth=14, max_features="sqrt", min_samples_leaf=5, min_samples_split=10, n_jobs=-1, random_state=random_state))
  ]
)

rsf_base.fit(X_base_train, y_train)

# get_predictions() takes no model-specific flag, so none is passed.
rsf_base_pred_train, rsf_base_pred_test, rsf_base_surv_train, rsf_base_surv_test = get_predictions(X_base_train, X_base_test, y_train, y_test, rsf_base)

rsf_base_results, rsf_base_bs_train, rsf_base_bs_test = get_metrics_df(X_base_train, X_base_test, y_train, y_test, rsf_base)

rsf_base_results

# Predicted survival curves for the fixed sample of test subscriptions.
pred_surv = rsf_base.predict_survival_function(sample_X_base_test)

time_points = np.arange(1, 720)
plt.figure()
for i, surv_func in enumerate(pred_surv):
    plt.step(time_points, surv_func(time_points), where="post",
             label="Churn={churn}, Time={time}d".format(churn=sample_y_churn[i], time=sample_y_time[i]))
# Raw string: "\h" is not a valid escape sequence in a normal string literal.
plt.ylabel(r"Est. probability of survival $\hat{S}(t)$")
plt.xlabel("Time $t$")
plt.legend(bbox_to_anchor=(0.7, 0.7))
plt.title('RSF baseline')
plt.show()

# Permutation importance on the test set, computed through the whole pipeline
# so the original (un-encoded) columns are the ones being permuted.
rsf_imp_base = permutation_importance(rsf_base, X_base_test, y_test, n_repeats=10, random_state=random_state)
sorted_idx = rsf_imp_base.importances_mean.argsort()

fig, ax = plt.subplots()
ax.boxplot(
    rsf_imp_base.importances[sorted_idx].T, vert=False, labels=X_base_train.columns[sorted_idx]
)
ax.set_title("RSF baseline, variable importance (permutation)")
fig.tight_layout()
plt.show()


# RSF with data on previous subscriptions
# Pipeline: one-hot encoding followed by a random survival forest.

rsf = Pipeline(
  [
    ('encode', OneHotEncoder()),
    ("model", RandomSurvivalForest(n_estimators=500, max_depth=14, max_features="sqrt", min_samples_leaf=5, min_samples_split=10, n_jobs=-1, random_state=random_state))
  ]
)

rsf.fit(X_train, y_train)

rsf_pred_train, rsf_pred_test, rsf_surv_train, rsf_surv_test = get_predictions(X_train, X_test, y_train, y_test, rsf, rf=True)

rsf_results, rsf_bs_train, rsf_bs_test = get_metrics_df(X_train, X_test, y_train, y_test, rsf)

rsf_results

# Predicted survival curves for the sampled test subscriptions.
pred_surv = rsf.predict_survival_function(sample_X_test)

time_points = np.arange(1, 720)
plt.figure()
for i, surv_func in enumerate(pred_surv):
    # NOTE(review): the other survival-curve plots label curves with
    # sample_y_churn / sample_y_time; here the label is looked up in y_test
    # through sample_mask.index - confirm both give the same values.
    plt.step(time_points, surv_func(time_points), where="post",
             label="Churn={churn}, Time={time}d".format(churn=int(y_test[sample_mask.index[i]][0]), time=int(y_test[sample_mask.index[i]][1])))
# Raw string: "\h" is not a valid Python escape sequence, so the non-raw
# literal triggers an invalid-escape warning; the label text is unchanged.
plt.ylabel(r"Est. probability of survival $\hat{S}(t)$")
plt.xlabel("Time $t$")
plt.legend(bbox_to_anchor=(0.7, 0.7))
plt.title('RSF prev')
plt.show()

# Permutation importance on the test set (10 repeats), plotted as box plots
# sorted by mean importance.
rsf_imp = permutation_importance(rsf, X_test, y_test, n_repeats=10, random_state=random_state)
sorted_idx = rsf_imp.importances_mean.argsort()

fig, ax = plt.subplots()
ax.boxplot(
    rsf_imp.importances[sorted_idx].T, vert=False, labels=X_train.columns[sorted_idx]
)
ax.set_title("RSF prev, variable importance (permutation)")
fig.tight_layout()
plt.show()

We estimated both Cox PH and RSF models in two variants:

  • with baseline variables only, i.e. with such features which are known at the start of subscription (member’s characteristics and data on his first transaction, from which subscription started), without any information about previous history of this user, except the number of previous churns (on the diagrams and in the tables below these models go with the index ‘baseline’),

  • with baseline characteristics and features concerning all previous subscriptions (on the diagrams and in the tables below these models go with the index ‘prev’).

Overall, RSF is a bit better than Cox PH in prediction performance measured by C-indices (Harrell’s C does not take the censoring distribution into account, while Uno’s C does). Both of them perform a bit better with features on previous subscriptions, although for RSF this gain is almost unnoticeable. In addition, both models give quite comparable results on the train and test sets:

# Metric tables exported by the Python models, one ";"-separated file each.
metrics <- lapply(
  c(CoxPH_base = "Cox_base.csv",
    CoxPH = "Cox.csv",
    RSF_base = "rsf_base.csv",
    RSF = "rsf.csv"),
  read.csv,
  sep = ";"
)

# Keep the first two rows of every table (the C-index rows), coerce the
# second column (Train) to numeric, and stack the four tables.
cindex <- lapply(metrics, function(tbl) {
  tbl <- tbl[1:2, ]
  tbl[, 2] <- as.numeric(tbl[, 2])
  tbl
})
cindex <- dplyr::bind_rows(cindex)

cindex$model <- rep(
  c("Cox PH baseline", "Cox PH prev", "RSF baseline", "RSF prev"),
  each = 2
)

# Long format: one row per metric / model / data set.
cindex <- rbind(
  mutate(rename(cindex[, c(1:2, 4)], value = Train), set = "Train"),
  mutate(rename(cindex[, c(1, 3:4)], value = Test), set = "Test")
)
cindex$Metric <- gsub("Harrel", "Harrell", cindex$Metric)

# Dodged bar chart of the C-indices by metric and model, one facet per data
# set, with the rounded value printed above each bar.
ggplot() +
  geom_bar(
    mapping = aes(x = Metric, y = value, fill = model),
    data = cindex,
    stat = "identity",
    position = position_dodge(0.8),
    color = "white"
  ) +
  geom_text(
    mapping = aes(x = Metric, y = y, label = lbl, group = model),
    data = mutate(cindex, y = value + 0.03, lbl = round(value, 2)),
    position = position_dodge(0.8),
    stat = "identity",
    size = 3,
    fontface = "bold"
  ) +
  facet_wrap(~ set) +
  scale_y_continuous(expand = c(0, 0), limits = c(0, 1)) +
  scale_fill_manual(
    values = paletteer_d("ggthemes::Tableau_20")[c(2, 1, 4, 3)]
  ) +
  labs(x = "", y = "", fill = "", title = "C-index", color = "") +
  ggthemes::theme_tufte() +
  theme(
    plot.title = element_text(face = "bold", hjust = 0.5),
    axis.ticks.x = element_blank(),
    strip.text = element_text(size = 12, face = "bold"),
    axis.line = element_line(size = .2, color = "grey20"),
    axis.text.x = element_text(size = 12, face = "bold"),
    strip.background = element_rect(fill = NA)
  )


For the integrated Brier score (IBS), calculated for the period from 1 to 720 days, the picture is the same:

# IBS on the test set: third row, third column of every metrics table.
ibs <- tibble(
  model = c("Cox PH baseline", "Cox PH prev", "RSF baseline", "RSF prev"),
  value = unlist(lapply(metrics, `[`, 3, 3))
)

# Horizontal bar chart with the rounded IBS printed inside each bar.
ggplot(data = ibs) +
  geom_bar(
    mapping = aes(x = model, y = value, fill = model),
    stat = "identity",
    width = 0.6
  ) +
  geom_text(
    mapping = aes(x = model, y = value - 0.01, label = round(value, 3)),
    size = 4,
    color = "grey10",
    fontface = "bold"
  ) +
  scale_y_continuous(
    expand = c(0, 0),
    limits = c(0, 0.16),
    breaks = seq(0, 0.15, 0.05)
  ) +
  scale_fill_manual(
    values = paletteer_d("ggthemes::Tableau_20")[c(2, 1, 4, 3)]
  ) +
  labs(x = "", y = "", title = "IBS, test sample") +
  coord_flip() +
  ggthemes::theme_tufte() +
  theme(
    legend.position = "none",
    plot.title = element_text(face = "bold", hjust = 0.5),
    strip.text = element_text(size = 12, face = "bold"),
    axis.line = element_line(size = .2, color = "grey20"),
    axis.text.y = element_text(
      size = 12, hjust = 0, color = "grey10", face = "bold",
      margin = margin(l = 10, r = -100)
    ),
    axis.ticks.y = element_blank()
  )


But if we turn our attention to the time-dependent Brier score (we calculated it for 7, 14, 30, 60, 90, 180, 360 and 720 days after the start of subscription), then advantages of RSF models over Cox PH are more prominent for intermediate time points:


We sampled 3 subscriptions from the test sample from each of two groups (subscriptions that actually ended in churn and censored ones) and predicted survival functions for them using all 4 models above:

Again, there are more differences between the predictions of the two Cox models than between those of the two RSF models. Above all, this could be due to a violation of the PH assumption.


For Cox regression, variable importance can be estimated using the proportion of log-likelihood that is explained by each variable (for a categorical variable, all its categories are combined into a so-called grouped variable). The plot below displays the Wald \(\chi^2\) statistic minus its degrees of freedom for assessing the partial effect of each variable. According to Harrell (2021), even though this is not a statistic scaled to [0,1], it is probably the best method in general because it penalizes a variable requiring a large number of parameters to achieve the \(\chi^2\).

# Modelling frame with time-invariant covariates: identifier, outcome
# (churn, time) and features. Cities with fewer than 30 rows are lumped into
# the level "19-20"; prev_pct_active is rescaled by a factor of 100.
df_feat <- transmute(
  trans_periods,
  subscr_id, churn, time,
  churned_before_num,
  cityG = fct_lump_min(city, 30, other_level = "19-20"),
  bdG, gender, registered_via, month_period,
  time_from_reg, time_from_prev_exp,
  first_payment_method_id, first_plan_list_price,
  prev_time,
  prev_pct_active = 100 * prev_pct_active,
  prev_mean_num_tot, prev_mean_total_secs_imp,
  prev_mean_num_unq, prev_mean_av_pct_played_act,
  prev_mean_actual_amount_paid, prev_num_cancel,
  prev_mode_payment_id
)

# Baseline variant: drop every column whose name contains "prev_"
# (this also matches time_from_prev_exp).
df_feat_base <- select(df_feat, -contains("prev_"))

# Feature names are all columns after the first three (subscr_id, churn, time).
feat_base <- tail(names(df_feat_base), -3)
feat_prev <- tail(names(df_feat), -3)

# Row indices exported from Python; 1 is added before they are used to index
# trans_periods (presumably because the exported indices are 0-based).
train_set <- read.csv("X_train_index.csv")
sample_X_test_id <- read.csv("sample_X_test_id.csv")

# Subscription ids of the sampled test periods.
sample_X_subscr_id <- trans_periods$subscr_id[sample_X_test_id$ID + 1]

# Train / test split of the baseline frame: train rows are those whose
# subscr_id is among the training ids, test rows are all the others.
df_base_tr <- filter(
  df_feat_base,
  subscr_id %in% trans_periods$subscr_id[train_set$ind + 1]
)
df_base_tst <- filter(
  df_feat_base,
  !(subscr_id %in% trans_periods$subscr_id[train_set$ind + 1])
)

# Same split for the frame with features on previous subscriptions.
df_tr <- filter(
  df_feat,
  subscr_id %in% trans_periods$subscr_id[train_set$ind + 1]
)
df_tst <- filter(
  df_feat,
  !(subscr_id %in% trans_periods$subscr_id[train_set$ind + 1])
)

# Sampled test periods in both feature variants.
samp_base_test <- filter(df_feat_base, subscr_id %in% sample_X_subscr_id)
samp_test <- filter(df_feat, subscr_id %in% sample_X_subscr_id)

# Model formulas: right-censored survival outcome against every feature of
# the respective variant.
fml_base <- as.formula(
  paste("Surv(time = time, event = churn) ~",
        paste(feat_base, collapse = " + "))
)
fml_prev <- as.formula(
  paste("Surv(time = time, event = churn) ~",
        paste(feat_prev, collapse = " + "))
)

# Baseline Cox PH: the rms fit feeds anova(); the survival::coxph fit is
# kept as a separate object.
cox_base_rms <- cph(fml_base, df_base_tr, x = TRUE, y = TRUE)
cox_base <- coxph(fml_base, df_base_tr)
cox_base_anova_rms <- anova(cox_base_rms)
png("cox_base_imp.png", width = 500, height = 400)
plot(cox_base_anova_rms, sort = "ascending")
dev.off()

# Cox PH with features on previous subscriptions.
cox_rms <- cph(fml_prev, df_tr, x = TRUE, y = TRUE)
cox <- coxph(fml_prev, df_tr)
cox_anova_rms <- anova(cox_rms)
png("cox_imp.png", width = 500, height = 400)
plot(cox_anova_rms, sort = "ascending")
dev.off()

# Show the two saved importance plots side by side.
p1 <- ggdraw() + draw_image("cox_base_imp.png", scale = 1)
p2 <- ggdraw() + draw_image("cox_imp.png", scale = 1)
ggarrange(
  p1, p2,
  nrow = 1, ncol = 2,
  labels = c("Cox PH baseline, variable importance",
             "Cox PH prev, variable importance"),
  hjust = -0.25
)


For RSF models we calculated weights for variable importance using 10 permutations - these weights show the value by which the concordance index on the test data will drop, on average, if the relationship of the according feature to survival time is removed (by random shuffling).


You can see that all models consider the first payment method used as the most important feature; the first plan list price is the second most important in RSF models, and the third in Cox PH models. Month of the subscription origin is more important in baseline specifications than in those extended with characteristics of the previous subscriptions, and in Cox regressions compared to RSF. The most considerable difference in the ranking of features between Cox PH and RSF models is observed for the most frequent payment method used in previous subscriptions: it holds second place in Cox regression and one of the last places in RSF. Among other historical features, the average percentage of a song’s length played matters the most.


To reveal how survival probability changes with different values of the most important (and other) variables, we can predict and plot survival curves for them in particular, holding all other features equal. In addition, for Cox regression we can interpret exponentiated coefficients as hazard ratios: if the hazard ratio > 1 (the corresponding coefficient is positive), then this feature is positively correlated with churn hazard; otherwise the correlation is negative. We won’t plot such survival curves here - just display plots for hazard ratios. All hazard ratios for categorical variables should be interpreted in comparison with the baseline category of the corresponding variable.

#' Map model term names to the feature they belong to
#'
#' Every element of `x` that matches a name in `feat` (as a regular
#' expression, via `grepl()`) is replaced by that name. Names are tried in the
#' order given; a later match overwrites an earlier one and is tested against
#' the already replaced value.
#'
#' @param x Character vector of term names (e.g. coefficient row names).
#' @param feat Character vector of feature names used as patterns.
#' @return Character vector of the same length as `x`; elements that match no
#'   feature are returned unchanged.
feature_group <- function(x, feat) {
  # Vectorised over x: unlike `for (i in 1:length(x))`, this does not fail
  # with "argument is of length zero" when x is empty.
  for (j in feat) {
    x[grepl(j, x)] <- j
  }
  x
}

# Hazard-ratio table of the baseline Cox model: term name, owning feature,
# HR (2nd column of the coefficient matrix) and standard error (3rd column),
# sorted by HR.
cox_base_coef <- summary(cox_base)$coefficients
cox_base_coef <- tibble(
  Variable = rownames(cox_base_coef),
  group = feature_group(Variable, feat_base),
  HR = cox_base_coef[, 2],
  se = cox_base_coef[, 3]
)
cox_base_coef <- arrange(cox_base_coef, HR, group)

# Forest-style plot: HR with 95% interval HR * exp(+/- 1.96 * se), terms
# ordered by HR and coloured by feature.
p1 <- ggplot(
  data = cox_base_coef,
  mapping = aes(x = Variable, y = HR, color = group)
) +
  geom_hline(yintercept = 1) +
  geom_pointrange(
    mapping = aes(ymin = HR * exp(-1.96 * se), ymax = HR * exp(1.96 * se))
  ) +
  scale_x_discrete(limits = cox_base_coef$Variable) +
  scale_color_manual(
    values = rev(paletteer_d("ggthemes::Tableau_10")[1:9])
  ) +
  coord_flip() +
  labs(
    y = "Hazard Ratio [95% CI]",
    x = "",
    title = "Cox PH baseline, hazard ratios"
  ) +
  ggthemes::theme_tufte() +
  theme(
    legend.position = "none",
    plot.title = element_text(face = "bold", hjust = 0.5),
    axis.line = element_line(size = .2, color = "grey20"),
    axis.ticks.y = element_blank(),
    panel.grid.major.x = element_line(size = .2, color = "grey90"),
    panel.grid.minor.x = element_line(size = .2, color = "grey90"),
    panel.grid.major.y = element_line(
      size = .2, color = "grey90", linetype = "dotted"
    )
  )

# Same table and plot for the Cox model with previous-subscription features.
cox_coef <- summary(cox)$coefficients
cox_coef <- tibble(
  Variable = rownames(cox_coef),
  group = feature_group(Variable, feat_prev),
  HR = cox_coef[, 2],
  se = cox_coef[, 3]
)
cox_coef <- arrange(cox_coef, HR, group)

p2 <- ggplot(
  data = cox_coef,
  mapping = aes(x = Variable, y = HR, color = group)
) +
  geom_hline(yintercept = 1) +
  geom_pointrange(
    mapping = aes(ymin = HR * exp(-1.96 * se), ymax = HR * exp(1.96 * se))
  ) +
  scale_x_discrete(limits = cox_coef$Variable) +
  scale_color_manual(
    values = paletteer_d("ggthemes::Classic_20")[1:19]
  ) +
  coord_flip() +
  labs(
    y = "Hazard Ratio [95% CI]",
    x = "",
    title = "Cox PH prev, hazard ratios"
  ) +
  ggthemes::theme_tufte() +
  theme(
    legend.position = "none",
    plot.title = element_text(face = "bold", hjust = 0.5),
    axis.line = element_line(size = .2, color = "grey20"),
    axis.ticks.y = element_blank(),
    panel.grid.major.x = element_line(size = .2, color = "grey90"),
    panel.grid.minor.x = element_line(size = .2, color = "grey90"),
    panel.grid.major.y = element_line(
      size = .2, color = "grey90", linetype = "dotted"
    )
  )

ggarrange(p1, p2, nrow = 1, ncol = 2)


4.4.2 Models with time-dependent covariates


# Collapse transactions to one row per subscription period and transaction
# date: amounts paid and cancellations are summed, the payment method is
# taken from the last row of the group, the plan list price is the maximum.
trans_df_date <- trans_df %>%
  group_by(subscr_id, transaction_date) %>%
  summarise(date = unique(transaction_date),
            actual_amount_paid = sum(actual_amount_paid),
            is_cancel = sum(is_cancel), 
            payment_method_id = payment_method_id[n()],
            plan_list_price = max(plan_list_price)) %>%
  ungroup()

# Combine daily listening logs, per-period covariates and daily transactions
# into one long table keyed by (subscr_id, date).
ulogs_trans_periods <- ulogs_periods %>%
  select(date, subscr_id, num_tot, num_unq, total_secs_imp, av_pct_played) %>%
  # Marker column: non-missing only for rows that come from the user logs.
  mutate(file = "ulogs") %>%
  # Full join: periods without any log rows are kept as well.
  full_join(trans_periods %>%
              select(subscr_id, msno, msno_subscr_id, churn, time, start_date, end_date,
                     churned_before_num, city, bdG, gender, registered_via, month_period,
                     first_payment_method_id, first_plan_list_price,
                     time_from_reg, time_from_prev_exp, prev_time,
                     prev_mean_num_tot, prev_mean_total_secs_imp,
                     prev_mean_num_unq, prev_mean_av_pct_played_act,
                     prev_mean_actual_amount_paid, prev_num_cancel,
                     prev_mode_payment_id),
            by = "subscr_id") %>%
  # The join key `date` is the transaction date only for periods that have
  # log rows and NA otherwise; ifelse() drops the Date class, so the value is
  # converted back with as.Date() and origin "1970-01-01".
  # Transaction rows without a matching (subscr_id, date) are kept by the
  # full join.
  full_join(trans_df_date %>% transmute(subscr_id, 
                                        date = as.Date(ifelse(subscr_id %in% ulogs_periods$subscr_id,
                                                              transaction_date, 
                                                              NA), "1970-01-01"),
                                        transaction_date, actual_amount_paid, plan_list_price,
                                        is_cancel, payment_method_id),
            by = c("subscr_id", "date"))

# Fill the gaps left by the full joins, order rows chronologically within
# each subscription period, and keep the churn flag only on the last row.
ulogs_trans_periods <- ulogs_trans_periods %>%
  # Rows without a log date fall back to the transaction date. ifelse() drops
  # the Date class, so `date` is numeric until it is converted back below.
  mutate(date = ifelse(is.na(date), transaction_date, date)) %>%
  # Days without activity / transactions count as zero.
  mutate_at(vars(num_tot, num_unq, total_secs_imp, av_pct_played,
                 actual_amount_paid, plan_list_price, is_cancel), ~ replace_na(., 0)) %>%
  group_by(subscr_id) %>%
  # Spread the per-period constants over all rows of the period.
  # NOTE(review): assumes each period has exactly one distinct non-missing
  # value per column; otherwise the length of unique() does not fit.
  mutate_at(vars(msno, msno_subscr_id, churn, time, start_date, end_date,
                 churned_before_num, city, bdG, gender, registered_via, month_period,
                 first_payment_method_id, first_plan_list_price,
                 time_from_reg, time_from_prev_exp, prev_time,
                 prev_mean_num_tot, prev_mean_total_secs_imp,
                 prev_mean_num_unq, prev_mean_av_pct_played_act,
                 prev_mean_actual_amount_paid, prev_num_cancel,
                 prev_mode_payment_id),
            ~ unique(.[!is.na(.)])) %>%
  ungroup() %>%
  mutate(date = as.Date(date, "1970-01-01")) %>%
  arrange(subscr_id, date) %>%
  # The event may only be flagged on the chronologically last interval.
  # full_join() appends unmatched rows at the end of the table, so the last
  # row in join order is not necessarily the latest date; rows are therefore
  # sorted by date BEFORE the churn flag is restricted to the last row.
  group_by(subscr_id) %>%
  mutate(churn = ifelse(row_number() != n(), 0, churn)) %>%
  ungroup()

# Build counting-process intervals (time_start, time_stop] and cumulative
# time-dependent covariates within each subscription period.
# NOTE(review): suppressWarnings() wraps the whole pipeline, so any warning
# raised inside it is hidden - confirm which warning is expected here.
ulogs_trans_periods <- suppressWarnings(
  ulogs_trans_periods %>%
    group_by(subscr_id) %>%
    # time_start: days since the first row's date of the period (0 if date is
    # missing). time_stop: start of the next row; for the last row it is
    # `time` when the date is missing, otherwise days from the first date to
    # end_date plus 1. A time_stop of 0 is replaced by 1.
    # cum_* columns divide running sums by time_stop; cum_pct_active and
    # mean_av_pct_played_act count rows coming from the user logs
    # (non-missing `file`).
    # NOTE(review): zoo::na.locf() drops leading NAs by default (na.rm = TRUE),
    # which shortens the vector if a period starts with a missing payment
    # method or price - confirm that cannot happen, or consider na.rm = FALSE.
    mutate(time_start = replace_na(as.numeric(date - date[1], 'days'), 0),
           time_stop = ifelse(is.na(lead(time_start)),
                              ifelse(is.na(date),
                                     time,
                                     as.numeric(end_date - date[1], 'days') + 1),
                              lead(time_start)),
           time_stop = ifelse(time_stop == 0, 1, time_stop),
           time_int = time_stop - time_start,
           cum_pct_active = cumsum(!is.na(file))/time_stop*100,
           cum_mean_num_tot = cumsum(num_tot)/time_stop,
           cum_mean_num_unq = cumsum(num_unq)/time_stop,
           cum_mean_total_secs_imp = cumsum(total_secs_imp)/time_stop,
           mean_av_pct_played_act = replace_na(cumsum(av_pct_played)/cumsum(!is.na(file)), 0),
           cum_mean_actual_amount_paid = cumsum(actual_amount_paid)/time_stop,
           cum_num_cancel = cumsum(is_cancel),
           last_payment_method_id = zoo::na.locf(payment_method_id),
           last_plan_list_price = zoo::na.locf(plan_list_price)) %>%
    ungroup())

# Persist both tables in the working directory.
saveRDS(trans_periods, "trans_periods.RDS")
saveRDS(ulogs_trans_periods, "ulogs_trans_periods.RDS")

# Row indices exported from Python; 1 is added before they are used to index
# trans_periods (presumably because the exported indices are 0-based).
train_set <- read.csv("X_train_index.csv")
sample_X_test_id <- read.csv("sample_X_test_id.csv")

# Subscription ids of the sampled test periods.
sample_X_subscr_id <- trans_periods$subscr_id[sample_X_test_id$ID + 1]

# Covariates that are constant within a period ...
features_baseline <- c(
  "churned_before_num",
  "cityG", "bdG", "gender",
  "registered_via", "month_period",
  "time_from_reg", "time_from_prev_exp", "prev_time"
)
# ... and covariates that change from interval to interval.
features_timedep <- c(
  "cum_pct_active", "cum_mean_num_tot", "cum_mean_total_secs_imp",
  "cum_mean_num_unq", "mean_av_pct_played_act", "cum_mean_actual_amount_paid",
  "cum_num_cancel", "last_payment_method_id", "last_plan_list_price"
)

# Counting-process modelling frame: interval bounds, event flag and features.
df_features <- transmute(
  ulogs_trans_periods,
  subscr_id, time_start, time_stop, churn,
  churned_before_num,
  cityG = fct_lump_min(city, 30, other_level = "19-20"),
  bdG, gender,
  registered_via, month_period,
  time_from_reg, time_from_prev_exp, prev_time,
  cum_pct_active, cum_mean_num_tot, cum_mean_total_secs_imp,
  cum_mean_num_unq, mean_av_pct_played_act, cum_mean_actual_amount_paid,
  cum_num_cancel, last_payment_method_id, last_plan_list_price
)

# Train rows are the intervals of the training periods, test rows all others.
df_train <- filter(
  df_features,
  subscr_id %in% trans_periods$subscr_id[train_set$ind + 1]
)
df_test <- filter(
  df_features,
  !(subscr_id %in% trans_periods$subscr_id[train_set$ind + 1])
)

# Intervals of the sampled test periods.
sample_test <- filter(df_features, subscr_id %in% sample_X_subscr_id)

# Predict survival after Cox PH

#' Survival estimates for one test subscription period
#'
#' Reads the globals `df_test`, `trans_periods`, `features_baseline` and
#' `features_timedep`.
#'
#' @param estimator Fitted model passed to `survfit()`.
#' @param id_period `subscr_id` of the period to predict for.
#' @return A list with `baseline` (tibble of `time` and `surv`, the first
#'   curve of the fit in which all time-dependent covariates are set to 0 and
#'   payment method / plan price are taken from the first transaction) and
#'   `complete` (the `surv` component of the fit on the observed rows).
survest_ind <- function(estimator, id_period) {
  # All counting-process rows of the requested period.
  period_rows <- filter(df_test, subscr_id == id_period)
  is_period <- trans_periods$subscr_id == id_period

  # Baseline scenario: observed time-invariant covariates, time-dependent
  # ones zeroed, payment method and price as at the start of the period.
  base_covs <- data.frame(
    unique(select(period_rows, all_of(features_baseline)))
  )
  base_covs[, features_timedep] <- 0
  base_covs$last_payment_method_id <-
    trans_periods$first_payment_method_id[is_period]
  base_covs$last_plan_list_price <-
    trans_periods$first_plan_list_price[is_period]

  base_newdata <- data.frame(
    base_covs[1, ],
    select(period_rows, time_start, time_stop, churn),
    row.names = NULL
  )
  base_fit <- survfit(estimator, newdata = base_newdata)

  # Complete scenario: the rows exactly as observed.
  full_newdata <- data.frame(select(period_rows, -subscr_id), row.names = NULL)
  full_fit <- survfit(estimator, newdata = full_newdata)

  list(
    baseline = tibble(time = base_fit$time, surv = base_fit$surv[, 1]),
    complete = full_fit$surv
  )
}

# Plot updated and not updated survival curves at #num points

# Build the animation frames for one test subscription period.
#
# idt    - subscr_id of the period; its time_stop values are read from the
#          global df_test.
# idsurv - list with elements `baseline` (time, surv) and `complete`
#          (matrix of survival estimates), as returned by survest_ind().
# num    - if the period has at least `num` intervals, only about `num` of
#          them (evenly spaced by index, plus the last one) are plotted.
#          NOTE(review): num = 1 makes the step floor(n / (num - 1)) a
#          division by zero - confirm num >= 2 is always used.
#
# Returns a list of ggplot objects: the first one shows the baseline curve,
# each following one the curve taken from the column of idsurv$complete
# that belongs to the selected time point, drawn from that time onwards.
# Every plot also gets the "Updated" line through the starting values of all
# curves and a point marking the current time.
survest_plots <- function(idt, idsurv, num=10) {
  times <- df_test %>%
    filter(subscr_id == idt) %>%
    select(time_stop) %>%
    pull()
  
  # Thin out the time points when there are many intervals.
  if (length(times) >= num) {
    times_select <- seq_along(times)
    times_select <- unique(c(seq(1, max(times_select), floor(max(times_select)/(num-1))), max(times_select)))
    times_select <- times[times_select]
  } else {
    times_select <- times
  }
  
  times_select <- times_select[order(times_select)]
  
  times_lab <- c("Baseline", times_select)
  
  # One colour per frame: baseline plus one per selected time point.
  colors <- paletteer_c("grDevices::Zissou 1", 1 + length(times_select))
  
  # Frame 1: baseline curve.
  dfp <- tibble(time = as.numeric(idsurv$baseline$time),
                surv = as.numeric(idsurv$baseline$surv), 
                lbl = sprintf("If not updated after %s%s", times_lab[1], ""))
  # NOTE(review): unlike the frames built in the loop below, this x scale
  # sets no limits = c(0, 821) - confirm the difference is intended.
  p <- ggplot(dfp) +
    geom_line(aes(x = time, y = surv, linetype = lbl),
              size = 1, color = colors[1]) +
    scale_y_continuous(breaks = seq(0,1,0.2),
                       labels = scales::percent_format(accuracy = 1),
                       expand = c(0.02,0),
                       limits = c(0, 1)) +
    scale_x_continuous(breaks = seq(0, 810, 90), expand = c(0.02,0)) +
    labs(x = "Days", y = "Exp. survival probability") +
    theme_classic() 
  plots <- list(p)
  
  # Starting value of every curve; used for the "Updated" line below.
  surv_upd <- dfp$surv[1]
  
  # One frame per selected time point: the curve from the matching column of
  # idsurv$complete, restricted to times at or after that point.
  for (i in seq_along(times_select)) {
    dfp <- tibble(time = as.numeric(idsurv$baseline$time[idsurv$baseline$time >= times_select[i]]),
                  surv = as.numeric(idsurv$complete[idsurv$baseline$time >= times_select[i], 
                                                    which(times == times_select[i])]),
                  lbl = sprintf("If not updated after %s%s", times_lab[i+1], "d."))
    surv_upd <- c(surv_upd, dfp$surv[1])
    
    p <- ggplot(dfp) +
      geom_line(aes(x = time, y = surv, linetype = lbl),
                size = 1, color = colors[i + 1]) +
      scale_y_continuous(breaks = seq(0,1,0.2), 
                         labels = scales::percent_format(accuracy = 1),
                         expand = c(0.02,0), limits = c(0, 1)) +
      scale_x_continuous(breaks = seq(0, 810, 90), expand = c(0.02,0),
                         limits = c(0, 821)) +
      labs(x = "Days", y = "Exp. survival probability", linetype = "") +
      theme_classic()
    
    plots[[i+1]] <- p
  }
  
  # "Updated" path: starting value of each curve at its time point.
  dfp <- tibble(time = c(0, times_select),
                surv = surv_upd)
  
  for (i in 1:nrow(dfp)) {
    plots[[i]] <- plots[[i]] +
      geom_line(aes(x = time, y = surv, linetype = "Updated"), dfp,
                size = 0.5, color = "grey20", show.legend = TRUE) +
      geom_point(aes(x = time, y = surv), 
                 dfp[i,], shape = 21,
                 size = 3, color = "black", fill = colors[i],
                 show.legend = FALSE) +
      scale_linetype_manual(values = c("solid", "dotted")) +
      labs(linetype = "") +
      guides(linetype = guide_legend(override.aes = list(color = c(colors[i], "grey50")))) +
      theme(legend.position = c(0.7,0.9),
            legend.background = element_blank())
    
    # Reorder the layers so the "Updated" line is drawn first, then the
    # curve, then the point.
    plots[[i]]$layers <- plots[[i]]$layers[c(2,1,3)]
    
  }
  
  plots
}

Here, in the train sample, we used all subintervals of those subscription periods which were a part of the training sample with time-independent covariates. The same holds for the test set. Overall, instead of 17747 subscription periods in the time-invariant training data and 7606 in the time-invariant test data, we obtained, respectively, 2580019 and 1098763 pseudo-observations after the counting-process transformation.

# Cox
# Cox PH on the counting-process training data; every column except the
# period identifier enters the model as a covariate.
cox_td <- coxph(
  Surv(time = time_start, time2 = time_stop, event = churn) ~ .,
  data = select(df_train, -subscr_id),
  x = TRUE
)

cox_td_sum <- summary(cox_td)

# Plot gifs for updated and not updated survival curves for sampled test periods
# + one additional for illustration

for (idt in c(sample_X_subscr_id, 16739)) {

  test <- survest_ind(cox_td, idt)
  test_plots <- survest_plots(idt, test)

  # Frames are written to a per-period folder below the session temp dir.
  # showWarnings = FALSE: re-running the chunk in the same session must not
  # warn about the folder already existing.
  dir_out <- file.path(tempdir(), paste0("surv_plots/", idt))
  dir.create(dir_out, recursive = TRUE, showWarnings = FALSE)

  # Remember exactly the files written in this run instead of listing the
  # folder afterwards: a folder left over from an earlier run may still
  # contain frames that no longer belong to the animation.
  imgs <- character(length(test_plots))

  for (i in seq_along(test_plots)) {

    # 100 + i keeps the file names in frame order when sorted by name.
    fp <- file.path(dir_out, paste0("id", idt, "_", 100+i, ".png"))

    ggsave(plot = test_plots[[i]],
           filename = fp, width = 5, height = 3, dpi = 300,
           device = "png")

    imgs[i] <- fp
  }

  img_list <- lapply(imgs, image_read)
  img_list <- lapply(img_list, image_scale, "x600")

  img_joined <- image_join(img_list)

  # One frame per second.
  img_animated <- image_animate(img_joined, fps = 1)

  image_write(image = img_animated,
              path = paste0("test_plots_", idt, ".gif"))
}


# Concordance of the time-dependent Cox model: the train value is taken from
# the model summary, the test value is computed on df_test.
cox_td_conc_train <- cox_td_sum$concordance[1]
cox_td_conc_test <- concordance(cox_td, newdata = data.frame(df_test))
cox_td_conc_test <- cox_td_conc_test$concordance


# Predicted survival for the test periods; `id` groups the rows of one
# subscription period.
cox_td_pred <- survfit(cox_td, newdata = df_test, type = "aalen",
                       id = df_test$subscr_id)

# The predictions come back as one long vector. A new curve is assumed to
# start wherever the time does not increase relative to the previous row
# (difference <= 0); the running count of such rows becomes the curve id.
# NOTE(review): the first row is compared against the filler value 3 -
# confirm the first predicted time can never exceed 3, otherwise the ids
# start at 0 instead of 1 (the split below is unaffected by that).
cox_td_pred_df <- data.frame(time = cox_td_pred$time,
                             prob = cox_td_pred$surv)
cox_td_pred_df$id <- cox_td_pred_df$time - replace_na(lag(cox_td_pred_df$time),3)
cox_td_pred_df$id <- cumsum(cox_td_pred_df$id <= 0)

cox_td_pred_list <- split(cox_td_pred_df, cox_td_pred_df$id)
cox_td_pred_list <- setNames(cox_td_pred_list, NULL)

# Per curve: keep times up to 720, keep only the probabilities (columns 1
# and 3, time and id, are dropped) and prepend 1 as the first value.
cox_td_pred_list_ibs720 <- lapply(cox_td_pred_list,
                                  function(x) {
                                    x <- x[x$time <= 720, ]
                                    x <- x[,-c(1,3)]
                                    c(1, x)
                                })

# Prediction object in the list layout passed to sbrier_ltrc().
# NOTE(review): `taus` is defined outside this section; survival.times = 0:720
# presumably requires one estimate per day in every curve - confirm.
cox_td_for_ibs720 <- list(survival.probs = cox_td_pred_list_ibs720,
                          survival.times = 0:720,
                          survival.tau = pmin(taus, 720))

# Integrated Brier score of the time-dependent Cox model on the test data.
data <- data.frame(df_test)
Survobj = Surv(data$time_start, data$time_stop, data$churn)
cox_ibs_td_pred <- sbrier_ltrc(obj = Survobj, id = data$subscr_id,
                               pred = cox_td_for_ibs720, type = "IBS")
# Formula for the forest with time-dependent covariates: counting-process
# outcome against all baseline and time-dependent features.
fml = as.formula(paste0("Surv(time = time_start, time2 = time_stop, event = churn, type = 'counting') ~ ",
                        paste(features_baseline, collapse = " + "),
                        " + ",
                        paste(features_timedep, collapse = " + ")))

# Fit the forest on the training intervals (10 trees, mtry = 6); `id` names
# the column that groups intervals of one subscription period.
# NOTE(review): ltrcrrf / predictProb / sbrier_ltrc are presumably from the
# LTRCforests package, which is not among the libraries loaded at the top of
# the file - confirm it is attached elsewhere.
rrf_fit <- ltrcrrf(formula = fml, ntree = 10, mtry = 6,
                   data = data.frame(df_train), 
                   id = subscr_id)
# test
# Survival probabilities for the test periods on days 0..720, then the
# integrated ("IBS") and time-dependent ("BS") Brier scores.
rrf_pred <- predictProb(object = rrf_fit, 
                        newdata = data.frame(df_test),
                        newdata.id = subscr_id,
                        time.eval = 0:720)
data <- data.frame(df_test)
Survobj = Surv(data$time_start, data$time_stop, data$churn)
rrf_pred_ibs <- sbrier_ltrc(obj = Survobj, id = data$subscr_id, 
                            pred = rrf_pred, type = "IBS")
rrf_pred_bs <- sbrier_ltrc(obj = Survobj, id = data$subscr_id, 
                           pred = rrf_pred, type = "BS")

# Predictions on days 0..821 (822 time points), reshaped to long format and
# matched to the interval end points of df_test for the concordance.
# NOTE(review): flattening survival.probs with as.numeric() assumes one
# column per period, in the order of X_test_subscr_id (defined outside this
# section) - confirm.
rrf_pred821 <- predictProb(object = rrf_fit, 
                           newdata = data.frame(df_test),
                           newdata.id = subscr_id,
                           time.eval = 0:821)
rrf_pred_prob <- tibble(subscr_id = rep(X_test_subscr_id, each = 822),
                        time_stop = rep(0:821, times = length(X_test_subscr_id)),
                        pred_prob = as.numeric(rrf_pred821$survival.probs))
rrf_pred_prob <- rrf_pred_prob %>%
  right_join(df_test %>% select(subscr_id, time_stop), by = c("subscr_id", "time_stop"))
data <- data.frame(df_test)
rrf_conc <- concordancefit(Surv(data$time_start, data$time_stop, data$churn), 
                           x = rrf_pred_prob %>% pull(pred_prob))

# train
# Same steps on the training data.
# NOTE(review): assumes the columns of survival.probs follow the order of
# unique(df_train$subscr_id) - confirm.
rrf_pred821_tr <- predictProb(object = rrf_fit, 
                              newdata = data.frame(df_train),
                              newdata.id = subscr_id,
                              time.eval = 0:821)
rrf_pred_tr_prob <- tibble(subscr_id = rep(unique(df_train$subscr_id), each = 822),
                           time_stop = rep(0:821, times = length(unique(df_train$subscr_id))),
                           pred_prob = as.numeric(rrf_pred821_tr$survival.probs))
rrf_pred_tr_prob <- rrf_pred_tr_prob %>%
  right_join(df_train %>% select(subscr_id, time_stop), by = c("subscr_id", "time_stop"))
data <- data.frame(df_train)
rrf_conc_tr <- concordancefit(Surv(data$time_start, data$time_stop, data$churn), 
                              x = rrf_pred_tr_prob %>% pull(pred_prob))

# Cut the 0..821 training prediction down to days 0..720 (first 721 rows)
# instead of predicting again; elements 2 to 4 of the prediction object are
# kept.
rrf_pred720_tr <- rrf_pred821_tr[2:4]
rrf_pred720_tr$survival.probs <- rrf_pred720_tr$survival.probs[1:721,]
rrf_pred720_tr$survival.times <- rrf_pred720_tr$survival.times[1:721]

# Time-dependent Brier score on the training data.
data <- data.frame(df_train)
Survobj = Surv(data$time_start, data$time_stop, data$churn)
rrf_pred_bs_tr <- sbrier_ltrc(obj = Survobj, id = data$subscr_id, 
                           pred = rrf_pred720_tr, type = "BS")

Cox PH with time-dependent covariates enabled increasing Harrell’s C-index to 0.878 on the train set and 0.876 on the test set (in comparison with 0.811 and 0.809, accordingly, for Cox PH with baseline variables and data on previous subscriptions), and decreasing IBS from 0.145 to 0.109 on the test set.

# Harrell's C of the time-invariant models plus hard-coded values for the
# two time-dependent models; `group` is the model family, `spec` the feature
# specification.
cindex_all <- rbind(
  filter(cindex, Metric == "Harrell`s C"),
  tibble(
    Metric = "Harrell`s C",
    value = c(0.878, 0.917, 0.876, 0.887),
    model = rep(c("Cox PH time-dep", "RSF time-dep"), times = 2),
    set = rep(c("Train", "Test"), each = 2)
  )
) %>%
  mutate(
    group = factor(grepl("RSF", model), labels = c("Cox PH", "RSF")),
    spec = factor(
      case_when(
        grepl("line", model) ~ 1,
        grepl("prev", model) ~ 2,
        TRUE ~ 3
      ),
      levels = 1:3,
      labels = c("Baseline", "Prev", "TimeDep")
    )
  )

# Bar chart of Harrell's C by specification, facetted by data set (rows) and
# model family (columns), with the rounded value above each bar.
ggplot() +
  geom_bar(
    mapping = aes(x = spec, y = value, fill = model),
    data = cindex_all,
    stat = "identity",
    position = position_dodge(0.8),
    color = "white"
  ) +
  geom_text(
    mapping = aes(x = spec, y = y, label = lbl, group = model),
    data = mutate(cindex_all, y = value + 0.06, lbl = round(value, 2)),
    position = position_dodge(0.8),
    stat = "identity",
    size = 4,
    fontface = "bold"
  ) +
  facet_rep_grid(set ~ group, switch = "y", repeat.tick.labels = TRUE) +
  scale_y_continuous(expand = c(0, 0, 0, 0.02), limits = c(0, 1)) +
  scale_fill_manual(
    values = paletteer_d("ggthemes::Hue_Circle")[c(17:19, 10:12)]
  ) +
  labs(x = "", y = "", fill = "", title = "Harrell's C-index", color = "") +
  ggthemes::theme_tufte() +
  theme(
    legend.position = "none",
    plot.title = element_text(face = "bold", hjust = 0.5),
    axis.ticks.x = element_blank(),
    strip.text = element_text(size = 12, face = "bold"),
    axis.line = element_line(size = .2, color = "grey20"),
    axis.text.x = element_text(size = 12, face = "bold", color = "black"),
    strip.background = element_rect(fill = NA),
    strip.placement = "outside",
    panel.spacing.y = unit(1, "lines")
  )

# Brier scores of the models without time-dependent covariates.
# Each file is semicolon-separated; columns 2 and 3 (named Train and Test, as
# required by the renames below) may use a decimal comma and are converted to
# numeric. The model label is attached per file while reading, so the result
# no longer depends on every file having exactly eight rows.
bs_files <- c("Cox PH baseline" = "cox_base_bs.csv",
              "Cox PH prev" = "cox_bs.csv",
              "RSF baseline" = "rsf_base_bs.csv",
              "RSF prev" = "rsf_bs.csv")

bs <- Map(
  function(path, label) {
    x <- read.csv(path, sep = ";")
    x[, 2] <- as.numeric(gsub(",", ".", x[, 2]))
    x[, 3] <- as.numeric(gsub(",", ".", x[, 3]))
    # rep() keeps this valid for a file with zero data rows
    x$model <- rep(label, nrow(x))
    x
  },
  bs_files,
  names(bs_files)
)
bs <- dplyr::bind_rows(bs)

# Long format: one row per (time point, model, set) with the score in `value`.
bs <- rbind(
  bs[, c(1:2, 4)] %>% rename(value = Train) %>% mutate(set = "Train"),
  bs[, c(1, 3:4)] %>% rename(value = Test) %>% mutate(set = "Test")
)

# Brier scores of the RSF model with time-dependent covariates, restricted to
# the listed time points and reshaped to long format: the BScore column becomes
# the "Test" rows and the rrf_pred_bs_tr...2. column the "Train" rows.
bs_rsf_td <- local({
  horizons <- c(7, 14, 30, 60, 90, 180, 360, 720)
  scores <- read.csv("rrf_td_bs.csv") %>%
    filter(Time %in% horizons)

  test_part <- scores %>%
    select(-rrf_pred_bs_tr...2.) %>%
    rename(value = BScore) %>%
    mutate(model = "RSF time-dep", set = "Test")

  train_part <- scores %>%
    select(-BScore) %>%
    rename(value = rrf_pred_bs_tr...2.) %>%
    mutate(model = "RSF time-dep", set = "Train")

  rbind(test_part, train_part)
})
                   
# Stack the Brier scores of all models and derive the plotting factors:
# `group` (Cox PH vs RSF) and `spec` (Baseline / Prev / TimeDep, taken from the
# "line" / "prev" substrings of the model label).
bs_all <- bs %>%
  rbind(bs_rsf_td) %>%
  mutate(
    group = factor(grepl("RSF", model), labels = c("Cox PH", "RSF")),
    spec = factor(
      case_when(
        grepl("line", model) ~ 1,
        grepl("prev", model) ~ 2,
        TRUE ~ 3
      ),
      1:3,
      c("Baseline", "Prev", "TimeDep")
    )
  )

# Brier score over time: one coloured line (with point markers) per model,
# with one panel per data set.
ggplot(bs_all, aes(x = Time, y = value, color = model)) +
  geom_line(size = 1) +
  geom_point(size = 2, show.legend = FALSE) +
  facet_rep_grid(~ set, switch = "y", repeat.tick.labels = TRUE) +
  scale_x_continuous(breaks = seq(0, 720, 90), expand = c(0.02, 0)) +
  scale_y_continuous(limits = c(0, 0.18), expand = c(0, 0)) +
  scale_color_manual(
    values = paletteer_d("ggthemes::Hue_Circle")[c(17:18, 10:12)]
  ) +
  labs(title = "Time-dependent Brier-score",
       x = "Days", y = "", fill = "", color = "") +
  # enlarge the legend keys
  guides(color = guide_legend(override.aes = list(size = 5))) +
  ggthemes::theme_tufte() +
  theme(
    legend.position = "bottom",
    plot.title = element_text(face = "bold", hjust = 0.5),
    axis.line = element_line(size = .2, color = "grey20"),
    axis.text.x = element_text(size = 10),
    strip.text = element_text(size = 12, face = "bold"),
    strip.background = element_rect(fill = NA),
    strip.placement = "outside",
    panel.spacing.y = unit(1, "lines")
  )


Another advantage of models with time-dependent covariates is that they allow survival predictions to be updated after the start of a subscription period, using all new information on the user’s behaviour.

For example, using the results of the Cox PH regression with time-dependent covariates, we can estimate an “updated” survival curve: it reflects the survival distribution predicted using all data obtained on a subscription up to the end of the corresponding subscription period. In addition, we predicted the survival curve at the zero point (baseline), assuming that the user had just subscribed to the service and had not yet listened to any music, so that only the member’s characteristics, data on his/her first transaction, the number of previous churns and the time since the previous expiration date are available. After that, we took the first time interval of the subscription and several intermediate points, and predicted survival using all data on the user’s behavior obtained up to each point. For the baseline and the other time points, the “future” parts of the predicted survival curves are shown. For the last time point available for the subscription, the “future” part is a continuation of the “updated” survival curve.

Example of updating survival curve for a random user


You can see that, as new data are obtained, the predicted survival curve approaches the “updated” one, while at the beginning it could be too pessimistic or too optimistic.

We can also plot survival curves starting not from the zero time point but from any later time, conditional on survival up to that point. This lets us predict survival over a subsequent period for the customers who are still active at that time (who then count as 100%). For example, for the first of the users above it could look as follows:

# Covariate rows of one test subscription (subscr_id 16739), without the id
# column and with row names reset.
newdata <- df_test %>%
  filter(subscr_id == 16739) %>%
  select(-subscr_id)
newdata <- data.frame(newdata, row.names = NULL)

# Predicted survival for `newdata` from the fitted `cox_td` model, starting at
# `start` days. With survfit's `start.time` argument the returned curve is
# conditional on surviving to that time. Standard errors are not computed.
survfit_from <- function(start) {
  survfit(cox_td, newdata = newdata, se.fit = FALSE, start.time = start)
}

out <- survfit_from(0)
out30 <- survfit_from(30)
out60 <- survfit_from(60)
out90 <- survfit_from(90)
out180 <- survfit_from(180)
out270 <- survfit_from(270)
out360 <- survfit_from(360)
out450 <- survfit_from(450)
out540 <- survfit_from(540)

# Conditional survival curves for one subscription, one line per start time.
# Each curve is prefixed with the point (start time, 1): survival is 100% at
# the moment we condition on.
cond_fits <- list("0" = out, "30d" = out30, "60d" = out60, "90d" = out90,
                  "180d" = out180, "270d" = out270, "360d" = out360,
                  "450d" = out450, "540d" = out540)
cond_starts <- c(0, 30, 60, 90, 180, 270, 360, 450, 540)
# Column of the survfit `surv` matrix that is plotted.
# NOTE(review): 714 is hard-coded; presumably it selects one particular row of
# `newdata` (e.g. the last interval of the subscription) - confirm.
curve_col <- 714

cond_curves <- do.call(
  rbind,
  Map(
    function(fit, start, label) {
      data.frame(time = c(start, fit$time),
                 surv = c(1, fit$surv[, curve_col]),
                 start = label)
    },
    cond_fits, cond_starts, names(cond_fits)
  )
)
# Ordered factor: keeps the legend (and drawing order) chronological instead
# of the alphabetical order of the labels ("0", "180d", "270d", "30d", ...).
cond_curves$start <- factor(cond_curves$start, levels = names(cond_fits))

ggplot(cond_curves, aes(x = time, y = surv, color = start)) +
  geom_line(size = 1) +
  scale_color_manual(
    values = setNames(paletteer_d("ggsci::deep_orange_material")[10:2],
                      names(cond_fits))
  ) +
  scale_y_continuous(expand = c(0, 0), limits = c(0, 1),
                     breaks = seq(0, 1, 0.2),
                     labels = scales::percent_format(accuracy = 1)) +
  scale_x_continuous(breaks = seq(0, 810, 90), expand = c(0, 0)) +
  labs(x = "Days", y = "Conditional survival probability",
       title = "Predicted conditional survival probability", color = "") +
  theme_classic() +
  theme(legend.position = "right",
        plot.title = element_text(face = "bold", size = 12),
        panel.grid.major = element_line(size = .2, colour = "grey90"),
        panel.grid.minor = element_line(size = .2, colour = "grey90"))


We can also estimate variable importance using the proportion of explainable log-likelihood in the Cox regression.

# Counting-process Cox model formula: (time_start, time_stop] intervals with
# the churn indicator on the left, baseline plus time-dependent features on
# the right.
fml <- as.formula(sprintf(
  "Surv(time = time_start, time2 = time_stop, event = churn) ~ %s + %s",
  paste(features_baseline, collapse = " + "),
  paste(features_timedep, collapse = " + ")
))

# Refit with rms::cph (keeping the design matrix and response) so that the
# anova table can be computed and plotted, sorted in ascending order.
cox_td_rms <- cph(formula = fml, data = df_train, x = TRUE, y = TRUE)
cox_td_anova_rms <- anova(cox_td_rms)
plot(cox_td_anova_rms, sort = "ascending")

Variable importance, Cox PH with time-dependent variables

The top contributors to model performance here are the time-dependent variables for the cumulative number of cancellations and the last payment method used by a client.


As for the coefficients and hazard ratios, they do not have such a clear interpretation in a Cox regression with time-dependent covariates, so we will not interpret them here.


5 Conclusion


The task of this study project was to predict churn probability using survival analysis. We found several advantages of using these methods for this kind of business task. Firstly, they make it possible to predict not only the probability of churn over some predefined period of time, but also its probability at different time points after the start of the client-company relationship. Secondly, they give information about the shape of the churn hazard function, which is useful in the analysis of attrition dynamics (e.g. critical points when a company loses its clients). Finally, they come with a rich set of helpful visual tools to illustrate all these results.

We studied the existing literature on both statistical models and machine learning techniques used for predicting survival distributions. In comparison with other tasks in machine learning, this one has several peculiarities concerning the definition of the target variable (caused by the phenomenon of censoring), its prediction and performance evaluation. As a consequence, “standard” ML models require adaptation for this task, and, as many results show, sometimes they fall behind statistical models in performance. Additional obstacles arise with the use of time-dependent predictors and time-dependent effects: almost none of the ML models can work with them, or, at least, there are not enough tools available for it.

We used a sample of observations from the dataset on the KKBox music subscription service available on Kaggle. We used the same definition of churn as in the competition, but modified the task so that we could use survival analysis to solve it.

After preprocessing we obtained data on subscription periods, some of which ended with the user’s attrition, while the others were censored. We divided our sample into train and test sets. As a baseline, we estimated a Cox proportional hazards regression. Then, using the test sample, we compared the results of this baseline model with a random survival forest, which does not assume proportional hazards – and it outperformed Cox PH only slightly. Both models improved after the addition of data on users’ behavior in their previous subscription periods, although we would not assert that this gain was significant.

Using the results from these models, we predicted survival curves for users from the test sample and ranked variables by their importance. We then estimated Cox PH and RSF models with time-dependent covariates based on users’ logs after the start of each subscription period. We showed that new information about clients gives a considerable gain in prediction performance and allows their survival curves and predictions to be updated over the course of the customer-company relationship.

References

Abdul-Fatawu, Majeed. 2020. “Accelerated Failure Time Models.” The Journal of Risk Management and Insurance 24 (2): 12–35.
Bou-Hamad, Imad, Denis Larocque, and Hatem Ben-Ameur. 2011. “A Review of Survival Trees.” Statistics Surveys 5: 44–71.
Breheny, Patrick. n.d.a. “Accelerated Failure Time Models.” The University of Iowa. Accessed December 4, 2021. https://myweb.uiowa.edu/pbreheny/7210/f15/notes.html.
———. n.d.b. “Time-Dependent Coefficients.” The University of Iowa. Accessed December 4, 2021. https://myweb.uiowa.edu/pbreheny/7210/f15/notes.html.
Camilleri, Liberato. 2019. “History of Survival Analysis.” History of Survival Analysis, March.
Emmert-Streib, Frank, and Matthias Dehmer. 2019. “Introduction to Survival Analysis in Practice.” Machine Learning and Knowledge Extraction 1 (September): 1013–38. https://doi.org/10.3390/make1030058.
Harrell, Frank. 2015. Regression Modeling Strategies with Applications to Linear Models, Logistic and Ordinal Regression, and Survival Analysis. Springer International Publishing.
———. 2021. “Rms: Regression Modeling Strategies.” 2021. https://cran.r-project.org/web/packages/rms.
Herndon, James Emmett. 1988. “A Parametric Survival Model Which Generates Monotonic and Non-Monotonic Hazard Functions and Incorporates Time-Dependent Covariables.” PhD thesis, Citeseer.
Kassambara, Alboukadel, Marcin Kosinski, Fabian Scheipl, and Biecek Przemyslaw. 2021. “Survminer.” 2021. https://cran.r-project.org/web/packages/survminer.
Kazmi, Rashid. 2019. “Survival Analysis to Understand Customer Retention.” November 13, 2019. https://towardsdatascience.com/survival-analysis-to-understand-customer-retention-e3724f3f7ea2.
Laine, Thomas, and Eric M Reyes. 2014. “Tutorial: Survival Estimation for Cox Regression Models with Time-Varying Coefficients Using SAS and R.” Journal of Statistical Software 61 (1): 1–23.
Mills, Melinda. 2010. Introducing Survival and Event History Analysis. Sage.
Pedregosa, F., G. Varoquaux, A. Gramfort, V. Michel, B. Thirion, O. Grisel, M. Blondel, et al. 2011. “Scikit-Learn: Machine Learning in Python.” Journal of Machine Learning Research 12: 2825–30.
Pölsterl, Sebastian. 2020. “Scikit-Survival: A Library for Time-to-Event Analysis Built on Top of Scikit-Learn.” Journal of Machine Learning Research 21 (212): 1–6. http://jmlr.org/papers/v21/20-729.html.
Rebora, Paola, Agus Salim, and Marie Reilly. 2014. “Bshazard: A Flexible Tool for Nonparametric Smoothing of the Hazard Function.” R Journal 6 (2).
Rodríguez, Germán. 2007. “Chapter 7. Survival Models.” In Lecture Notes on Generalized Linear Models. https://data.princeton.edu/wws509/notes/.
Royston, Patrick, and Mahesh KB Parmar. 2002. “Flexible Parametric Proportional-Hazards and Proportional-Odds Models for Censored Survival Data, with Application to Prognostic Modelling and Estimation of Treatment Effects.” Statistics in Medicine 21 (15): 2175–97.
Schultz, Lonni R, Edward L Peterson, and Naomi Breslau. 2002. “Graphing Survival Curve Estimates for Time-Dependent Covariates.” International Journal of Methods in Psychiatric Research 11 (2): 68–74.
Simon, Noah, Jerome Friedman, Trevor Hastie, and Rob Tibshirani. 2011. “Regularization Paths for Cox’s Proportional Hazards Model via Coordinate Descent.” Journal of Statistical Software 39 (5): 1.
Sonabend, Raphael Edward Benjamin. 2021. “A Theoretical and Methodological Framework for Machine Learning in Survival Analysis: Enabling Transparent and Accessible Predictive Modelling on Right-Censored Time-to-Event Data.” PhD thesis, UCL (University College London).
Therneau, Terry M. 2021. A Package for Survival Analysis in R. https://CRAN.R-project.org/package=survival.
Therneau, Terry M, and D Watson. 2017. “The Concordance Statistic and the Cox Model.” Department of Health Science Research Mayo Clinic Technical Report 85: 1–18.
Therneau, Terry, Cindy Crowson, and Elizabeth Atkinson. 2017. “Using Time Dependent Covariates and Time Dependent Coefficients in the Cox Model.” Survival Vignettes 2: 3.
Vargas, Rafael. 2018. “Survival Analysis of Mobile Prepaid Customers Using the Weibull Distribution.”
Wang, Ping, Yan Li, and Chandan K Reddy. 2019. “Machine Learning for Survival Analysis: A Survey.” ACM Computing Surveys (CSUR) 51 (6): 1–36.
Yao, Weichi, Halina Frydman, Denis Larocque, and Jeffrey S Simonoff. 2020. “Ensemble Methods for Survival Data with Time-Varying Covariates.” arXiv Preprint arXiv:2006.00567.
Yao, Weichi, Halina Frydman, Denis Larocque, and Jeffrey S. Simonoff. 2021. “LTRCforests: Ensemble Methods for Survival Data with Time-Varying Covariates.” 2021. https://cran.r-project.org/web/packages/LTRCforests.
