#setwd("qiulab/cdg-data/")
# Input tables; paths are resolved against the current working directory.
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable variables.

# SNP matrix: no header row; the first column supplies the row names.
genes <- read.csv("cdgSNPmatrix-Jinyuan.csv3", header = FALSE, row.names = 1)

# Phenotype table with a header row (the `logcdg` and `strains` columns are used later).
# NOTE(review): `na.strings = TRUE` is kept exactly as before, but `na.strings`
# is documented as a character vector; TRUE is presumably coerced to "TRUE",
# which would replace the default "NA" marker. Confirm this is intended.
cdg <- read.csv("cdgTable.csv2", header = TRUE, na.strings = TRUE)

# Column labels for the SNP matrix, stored as a file without a header row.
header <- read.csv("header.csv", sep = ",", header = FALSE)

# Pairwise Manhattan distances between the rows of `genes`
# (Manhattan chosen because the entries are 0/1).
d.snp <- dist(x = genes, method = "manhattan")

# Hierarchical clustering with hclust()'s default linkage, then the dendrogram.
hc.snp <- hclust(d = d.snp)
plot(x = hc.snp)

# Cut the dendrogram into k = 2, ..., 10 clusters; element i of the list holds
# the cluster assignment for k = i + 1.
grp.snps <- lapply(2:10, function(n.clusters) cutree(hc.snp, k = n.clusters))

# Mean of logcdg within each level of `strains`.
cdg.mean <- tapply(X = cdg$logcdg, INDEX = cdg$strains, FUN = mean)

# ---- Hyper-parameters and result containers ----
n.models <- 9      # number of models, one per cluster count in 2:10
n.epochs <- 1000   # rows reserved for the per-epoch accuracy trace

eta <- 0.1             # learning rate
training <- 20         # number of rows used for fitting
target <- 10           # number of rows used for evaluation

b <- runif(n.models)   # one randomly initialised bias per model

# Per-model lists, filled in by the training loop.
w <- list()            # current weight vector of each model
weights <- list()      # per-epoch weight history of each model
snp <- list()          # training matrix of each model
target.snp <- list()   # evaluation matrix of each model

# Per-epoch training accuracy (epoch x model).
accuracy <- matrix(NA, nrow = n.epochs, ncol = n.models)

# Activation, output and error of the latest epoch (training row x model).
a <- y <- e <- matrix(NA, nrow = training, ncol = n.models)

# Train one single-layer logistic model per clustering in `grp.snps`.
#
# For clustering l (with `group` clusters) the features are `group` SNPs, one
# drawn at random from each cluster. The label of a strain is 0 when its mean
# logcdg is below -1 and 1 otherwise.
#
# Fixes relative to the previous version:
#  * the per-cluster loop overwrote snp[[l]] on every pass, and the sampled
#    positions indexed `genes` instead of the cluster's own rows; now exactly
#    one member of every cluster is drawn;
#  * labels were always looked up with the row order of snp[[1]], so they did
#    not line up with the shuffled rows of snp[[l]] for l > 1; labels are now
#    attached before shuffling;
#  * sample(training) / sample(target) only permute 1:training / 1:target, so
#    the "target" rows were a subset of the training rows; the strains are now
#    split into disjoint training and target sets;
#  * `weights` was sized with the leftover loop variable `state`.
#
# Reads:  genes, header, grp.snps, cdg.mean, eta, training, target
# Writes: snp, target.snp, w, b, weights, accuracy, a, y, e

geno <- as.matrix(genes)  # assumes rows are SNPs and columns are strains -- TODO confirm
strain.names <- as.character(as.matrix(header[-1]))
n.epoch <- nrow(accuracy)  # epoch count follows the preallocated accuracy matrix

stopifnot(
  length(strain.names) == ncol(geno),
  training + target <= ncol(geno),
  nrow(a) == training,
  ncol(accuracy) >= length(grp.snps),
  length(b) >= length(grp.snps)
)

# Labels depend only on the strain, so compute them once.
label <- as.vector(ifelse(cdg.mean[strain.names] < -1, 0, 1))
if (anyNA(label)) {
  stop("no mean logcdg available for strain(s): ",
       paste(strain.names[is.na(label)], collapse = ", "),
       call. = FALSE)
}

for (l in seq_along(grp.snps)) {
  group <- length(table(grp.snps[[l]]))

  # One representative SNP per cluster. sample.int() on the member count avoids
  # the sample(x) pitfall where a single number x is treated as the range 1:x.
  picks <- vapply(seq_len(group), function(state) {
    members <- which(grp.snps[[l]] == state)
    unname(members[sample.int(length(members), 1L)])
  }, integer(1))

  # Strains in rows, chosen SNPs in columns, label in column 1.
  feat <- t(geno[picks, , drop = FALSE])
  rownames(feat) <- strain.names
  labelled <- cbind(label, feat)

  # Disjoint random split into training and held-out target strains.
  shuffled <- sample(nrow(labelled))
  snp[[l]] <- labelled[shuffled[seq_len(training)], , drop = FALSE]
  target.snp[[l]] <- labelled[shuffled[training + seq_len(target)], , drop = FALSE]

  w[[l]] <- runif(group, 1e-3, 1e-2)  # small random starting weights
  weights[[l]] <- matrix(0, nrow = n.epoch, ncol = group)

  x.train <- snp[[l]][, -1, drop = FALSE]
  truth <- snp[[l]][, 1]

  for (epoch in seq_len(n.epoch)) {
    a[, l] <- x.train %*% w[[l]]                       # activation
    y[, l] <- 1 / (1 + exp(-a[, l] - b[l]))            # sigmoid output
    e[, l] <- truth - y[, l]                           # prediction error
    # Batch update along the error signal: w += eta * X'e, b += eta * sum(e).
    w[[l]] <- w[[l]] + eta * colSums(x.train * e[, l])
    b[l] <- b[l] + eta * sum(e[, l])
    weights[[l]][epoch, ] <- w[[l]]
    # Accuracy of the outputs computed at the start of this epoch.
    accuracy[epoch, l] <- mean(round(y[, l]) == truth)
  }
}

# Score every trained model on its target rows.
# target.y[[i]]  : sigmoid outputs of model i for the rows of target.snp[[i]]
# target.acc[i]  : fraction of those rows whose rounded output equals the
#                  label stored in column 1
# Results are preallocated instead of grown with c(), the loop follows the
# number of trained models instead of a literal 9, and the feature columns are
# "everything but the label column" instead of a range tied to the loop index.
target.y <- vector("list", length(w))
target.acc <- numeric(length(w))
for (i in seq_along(w)) {
  held.out <- target.snp[[i]]
  z <- held.out[, -1, drop = FALSE] %*% w[[i]] + b[i]
  target.y[[i]] <- 1 / (1 + exp(-z))
  hits <- round(target.y[[i]]) == held.out[, 1]
  target.acc[i] <- sum(hits, na.rm = TRUE) / nrow(held.out)
}

# Accuracy per cluster count: final-epoch training accuracy (colour 2) against
# accuracy on the target rows (colour 1).
group.axis <- 2:10
train.final <- accuracy[1000, 1:9]

plot(NA,
     xlim = c(2, 10), ylim = c(0, 1),
     xlab = "group", ylab = "accuracy",
     main = "Training vs. Target")
lines(group.axis, target.acc, col = 1, lwd = 2)
lines(group.axis, train.final, col = 2, lwd = 2)
legend("bottom",
       legend = c("Training", "Target"),
       text.col = c(2, 1),
       horiz = TRUE)

# Final-epoch training accuracy per cluster count.
# The x coordinates are given explicitly: without them the nine values are
# drawn at indices 1..9, so with xlim = c(2, 10) every point sits one group too
# far left and the first one falls outside the plot.
plot(x = 2:10, y = accuracy[1000, 1:9], xlim = c(2, 10), type = "p",
     main = "Accuracy of prediction", xlab = "group", ylab = "accuracy",
     col = "blue")

# Weight trajectories over the epochs: one solid line per weight, one colour
# per model.
plot(NA,
     xlim = c(1, 1000), ylim = c(-10, 10),
     xlab = "epoch", ylab = "weights range",
     main = "Weight epoch throughout groups")
clr <- rainbow(9)
for (i in seq_along(clr)) {
  # matlines() draws every column of the history matrix in one call; lty is
  # pinned to 1 because its default would cycle through line types.
  matlines(x = 1:1000, y = weights[[i]], type = "l", lty = 1, col = clr[i])
}