1. Data

다음과 같은 데이터가 있다고 가정하자.

각 행은 고유한 id를 가진 개별 개체이다.

나머지 3개의 변수는 continent_large -> continent_small -> country 순으로 계층적인 구조를 가지고 있다.

# Inspect the first and last rows of `data` (assumed to be a data.frame loaded
# earlier, not shown here). The "##" lines are the printed console output.
head(data)
##   id continent_large continent_small     country
## 1  1            Asia       East Asia South Korea
## 2  2            Asia       East Asia South Korea
## 3  3            Asia       East Asia South Korea
## 4  4            Asia       East Asia South Korea
## 5  5            Asia       East Asia South Korea
## 6  6            Asia       East Asia South Korea
tail(data)
##          id continent_large continent_small country
## 29044 29044          Europe Southern Europe  Greece
## 29045 29045          Europe Southern Europe  Greece
## 29046 29046          Europe Southern Europe  Greece
## 29047 29047          Europe Southern Europe  Greece
## 29048 29048          Europe Southern Europe  Greece
## 29049 29049          Europe Southern Europe  Greece


2. Plotly

plotly 패키지에서 제공하는 함수로 편하게 sankey diagram, sunburst chart를 만들 수 있다.

하지만 위와 같은 형태(한 행에 한 개체)로 수집된 데이터로는 plotly가 요구하는 벡터들(sunburst의 labels, parents, values / sankey의 source, target, value)을 직접 구성하기가 번거롭다.

library(plotly)

# Minimal hand-written sankey diagram. Each link's source/target is a 0-based
# position in node$label, and value is the width of that link.
fig <- plot_ly(
  type = "sankey",
  orientation = "h",
  node = list(
    label = c("A1", "A2", "B1", "B2", "C1", "C2"),
    color = rep("blue", 6),
    pad = 15,
    thickness = 20,
    line = list(color = "black", width = 0.5)
  ),
  link = list(
    source = c(0, 1, 0, 2, 3, 3),
    target = c(2, 3, 3, 4, 4, 5),
    value = c(8, 4, 2, 8, 4, 2)
  )
)
fig <- layout(fig, title = "Basic Sankey Diagram", font = list(size = 10))

fig
# Minimal hand-written sunburst chart: parents[k] names the parent of
# labels[k], and "" marks the root node.
fig <- plot_ly(
  type = "sunburst",
  labels = c(
    "Eve", "Cain", "Seth", "Enos", "Noam", "Abel", "Awan", "Enoch", "Azura"
  ),
  parents = c("", "Eve", "Eve", "Seth", "Seth", "Eve", "Eve", "Awan", "Eve"),
  values = c(10, 14, 12, 10, 2, 6, 6, 4, 4)
)

fig


3. Functions

# S3 generic. Dispatches on the class of the first argument passed, e.g.
# sunburst(<data.frame>, vars) or sunburst(<sunburst_df>).
sunburst <- function(...) {
  UseMethod("sunburst")
}

# Convert a data frame with hierarchical categorical columns into the
# ids / labels / parents / values layout used by plotly sunburst charts.
#
# Args:
#   data:   data.frame (or tibble) with one row per observation.
#   vars:   character vector of column names, ordered from the outermost
#           (root) level to the innermost (leaf) level.
#   center: optional single label. When supplied, an extra root node with
#           this label is added, and every first-level node becomes its child.
#
# Returns:
#   A data.frame of class c("sunburst_df", "data.frame") with columns
#   ids, labels, parents and values. Each value is a row count. Ids of
#   nested nodes join the ancestor labels and the node label with "-".
#   Categories with a zero count (e.g. unused factor levels) are omitted.
sunburst.data.frame <- function(data, vars, center) {
  stopifnot(
    is.character(vars),
    length(vars) >= 1L,
    all(vars %in% names(data))
  )

  # Count the categories of column `v` inside each group of `groups`,
  # dropping zero counts and groups with nothing left.
  count_within <- function(groups, v) {
    tabs <- lapply(groups, function(d) {
      tab <- table(d[[v]])
      tab[tab > 0]
    })
    tabs[lengths(tabs) > 0L]
  }

  level_frames <- vector("list", length(vars))
  for (i in seq_along(vars)) {
    v <- vars[[i]]
    if (i == 1L) {
      tabs <- count_within(list(data), v)
      labels <- as.character(unlist(lapply(tabs, names), use.names = FALSE))
      ids <- labels
      parents <- rep("", length(ids))
    } else {
      # Split on all ancestor columns at once. drop = TRUE skips empty
      # combinations. sep = "-" makes each group name equal to the parent
      # id built at the previous level, so labels containing "." are kept
      # intact. (Rewriting "." to "-" afterwards would break the parent links.)
      groups <- split(
        data, data[vars[seq_len(i - 1L)]], drop = TRUE, sep = "-"
      )
      tabs <- count_within(groups, v)
      labels <- as.character(unlist(lapply(tabs, names), use.names = FALSE))
      parents <- as.character(rep(names(tabs), lengths(tabs)))
      ids <- paste(parents, labels, sep = "-")
    }
    values <- as.integer(unlist(tabs, use.names = FALSE))

    if (i == 1L && !missing(center)) {
      labels <- c(center, labels)
      ids <- c(center, ids)
      parents <- c("", rep(center, length(parents)))
      values <- c(sum(values), values)
    }

    level_frames[[i]] <- data.frame(
      ids = ids, labels = labels, parents = parents, values = values,
      stringsAsFactors = FALSE
    )
  }

  result <- do.call(rbind, level_frames)
  # Keep the data.frame class so the result still prints and subsets like
  # one. "sunburst_df" comes first so sunburst() dispatches to the plot method.
  class(result) <- c("sunburst_df", "data.frame")
  result
}

# Draw a plotly sunburst chart from a "sunburst_df" object.
#
# Args:
#   object:       object whose ids, labels, parents and values elements are
#                 passed straight to plotly (as produced by the data.frame
#                 method of sunburst()).
#   branchvalues: "remainder" (default) or "total". It is validated with
#                 match.arg() and forwarded to plotly unchanged.
#
# Returns:
#   The figure created by plotly::plot_ly().
sunburst.sunburst_df <- function(object,
                                 branchvalues = c("remainder", "total")) {
  branchvalues <- match.arg(branchvalues)
  plotly::plot_ly(
    type = "sunburst",
    ids = object$ids,
    labels = object$labels,
    parents = object$parents,
    values = object$values,
    branchvalues = branchvalues
  )
}

# Build the sunburst table from the three hierarchy columns, then plot it.
sdf <- sunburst(
  data,
  vars = c("continent_large", "continent_small", "country")
)
sunburst(sdf)

sunburst라는 generic 함수를 만들고 data.frame, sunburst_df class에 대응되는 함수를 따로 만들었다.

sunburst.data.frame 함수는 ids, labels, parents, values 벡터를 만들어 주고, sunburst_df라는 class의 결과값을 리턴한다.

sunburst.sunburst_df 함수는 sunburst_df class object와 plotly 함수를 활용하여 그래프를 그린다.

# S3 generic. Dispatches on the class of the first argument passed, e.g.
# sankey(<data.frame>, vars) or sankey(<sankey_df>).
sankey <- function(...) {
  UseMethod("sankey")
}

# Convert a data frame with categorical columns into the link table used by a
# plotly sankey diagram.
#
# Args:
#   data: data.frame (or tibble) with one row per observation.
#   vars: character vector (length >= 2) of column names. A link is built
#         between each pair of consecutive columns: vars[1] -> vars[2],
#         vars[2] -> vars[3], and so on.
#
# Returns:
#   A list of class "sankey_df" with three elements:
#     sankey_dat: data.frame with source/target (0-based indices into
#                 `label`), value (row count), and label_source/label_target
#                 (the node labels of each link).
#     label:      character vector of node labels.
#     tab_list:   one frequency table (source, target, value) per column pair.
#   If any value occurs in more than one of `vars`, every column involved is
#   prefixed with "<column>." so that the nodes stay distinct.
sankey.data.frame <- function(data, vars) {
  stopifnot(
    is.character(vars),
    length(vars) >= 2L,
    all(vars %in% names(data))
  )

  # Consecutive column pairs (unnamed, so rbind below keeps plain row names).
  vars_list <- lapply(
    seq_len(length(vars) - 1L),
    function(k) vars[c(k, k + 1L)]
  )

  # Columns sharing at least one value with some other column in `vars`.
  vars_unique <- lapply(data[vars], unique)
  overlaps_other <- vapply(vars, function(v) {
    others <- vars_unique[names(vars_unique) != v]
    any(vapply(others, function(x) any(x %in% vars_unique[[v]]), logical(1)))
  }, logical(1))
  is_dup <- vars[overlaps_other]

  # Prefix the shared values with their column name to keep the nodes distinct.
  for (v in is_dup) {
    data[[v]] <- paste(v, data[[v]], sep = ".")
  }

  # Row counts for every observed (source, target) combination of one pair.
  count_pair <- function(pair) {
    dsub <- data[pair]
    out <- aggregate(rep(1, nrow(dsub)), dsub, sum)
    names(out) <- c("source", "target", "value")
    out$source <- as.character(out$source)
    out$target <- as.character(out$target)
    out
  }
  tab_list <- lapply(vars_list, count_pair)

  # Stack all pairs into a single link table for plotly.
  sankey_dat <- do.call(rbind, tab_list)
  sankey_dat$label_source <- sankey_dat$source
  sankey_dat$label_target <- sankey_dat$target

  label <- unique(c(sankey_dat$source, sankey_dat$target))
  # plotly's link source/target refer to node positions counted from 0.
  sankey_dat$source <- match(sankey_dat$source, label) - 1
  sankey_dat$target <- match(sankey_dat$target, label) - 1

  structure(
    list(sankey_dat = sankey_dat, label = label, tab_list = tab_list),
    class = "sankey_df"
  )
}

# Draw a plotly sankey diagram from a "sankey_df" object.
#
# Args:
#   object:  result of the data.frame method of sankey(). Its sankey_dat,
#            label and tab_list elements are used.
#   palette: RColorBrewer palette name used for the node colours.
#   tables:  if TRUE, also return the per-pair frequency tables.
#   title:   plot title.
#
# Returns:
#   The plotly figure. If `tables` is TRUE, the return value is instead
#   list(p = <figure>, tab_list = object$tab_list).
sankey.sankey_df <- function(object, palette = "Set1", tables = FALSE, title = "") {
  links <- object$sankey_dat
  node_labels <- object$label
  n_nodes <- length(node_labels)

  # Take one colour per node. Indexing past the end of what brewer.pal()
  # returned leaves NA entries, which the loop fills from further
  # brewer.pal() draws (assignment recycles the drawn colours).
  node_colors <- suppressWarnings(
    RColorBrewer::brewer.pal(n_nodes, palette)
  )[1:n_nodes]
  repeat {
    gaps <- is.na(node_colors)
    if (!any(gaps)) break
    suppressWarnings(
      node_colors[gaps] <- RColorBrewer::brewer.pal(sum(gaps), palette)
    )
  }

  fig <- plotly::plot_ly(
    type = "sankey",
    orientation = "h",
    node = list(
      label = node_labels,
      color = node_colors,
      pad = 15,
      thickness = 20,
      line = list(color = "black", width = 0.5)
    ),
    link = list(
      source = links$source,
      target = links$target,
      value = links$value
    )
  )
  fig <- plotly::layout(p = fig, title = title, font = list(size = 10))

  if (tables) {
    return(list(p = fig, tab_list = object$tab_list))
  }
  fig
}

# Build the sankey link table from the three hierarchy columns, then plot it.
sdf <- sankey(
  data,
  vars = c("continent_large", "continent_small", "country")
)
sankey(sdf)

sunburst와 마찬가지로 sankey라는 generic 함수를 만들고 data.frame, sankey_df class에 대응되는 함수를 따로 만들었다.

sankey.data.frame 함수는 source, target, value 벡터를 만들어 주고, sankey_df라는 class의 결과값을 리턴한다.

sankey.sankey_df 함수는 sankey_df class object와 plotly 함수를 활용하여 그래프를 그린다.