#setwd("qiulab/cdg-data/")
# Data ----
# SNP matrix: row names come from the file's first column, and each row is
# clustered below. Later code transposes it so rows become strains, so the
# columns are presumably strains (values presumably 0/1 -- see distance note).
genes <- read.csv("cdgSNPmatrix-Jinyuan.csv3", header = FALSE, row.names = 1)
# Measurement table; later code uses its `strains` and `logcdg` columns.
# `na.strings` must be a character vector: the previous `na.strings = T` is a
# logical, which scan() rejects as an invalid 'na.strings' argument.
cdg <- read.csv("cdgTable.csv2", header = TRUE, na.strings = "NA")
# Strain names; the first field is skipped by later code.
header <- read.csv("header.csv", header = FALSE)

# Hierarchical clustering of SNPs. Manhattan distance is used because the data
# are 0/1, so it counts the columns in which two SNPs differ.
d.snp <- dist(genes, method = "manhattan")
hc.snp <- hclust(d.snp)
plot(hc.snp)

# Model setup ----
# Cut the SNP dendrogram into k = 2, ..., 10 clusters. Each cut defines one
# logistic model whose features are one SNP drawn at random from each cluster.
grp.snps <- lapply(2:10, function(k) cutree(hc.snp, k = k))
# Mean logcdg per strain; the binary class label is derived from it below.
cdg.mean <- tapply(cdg$logcdg, cdg$strains, mean)
n.model <- length(grp.snps)
n.epoch <- 1000
eta <- 0.1 # learning rate
w <- list() # current weight vector of each model
b <- runif(n.model) # bias of each model
training <- 20 # number of strains used for training
target <- 10 # number of held-out strains, disjoint from the training set
snp <- list() # per model: training rows; column 1 = label, rest = features
target.snp <- list() # per model: target rows, same layout as `snp`
weights <- list() # per model: weight trajectory, one row per epoch
accuracy <- matrix(NA, nrow = n.epoch, ncol = n.model)
a <- matrix(NA, nrow = training, ncol = n.model)
y <- matrix(NA, nrow = training, ncol = n.model)
e <- matrix(NA, nrow = training, ncol = n.model)

# Strain names in column order of `genes`. The first header field is dropped,
# presumably because it lines up with the row-name column of the SNP file.
strain.names <- as.character(as.matrix(header[-1]))
stopifnot(training + target <= length(strain.names))

# Binary label per strain: 0 when mean logcdg < -1, otherwise 1. It does not
# depend on the model, so it is computed once, in `strain.names` order.
label <- ifelse(as.vector(cdg.mean[strain.names]) < -1, 0, 1)
if (anyNA(label)) {
  # An NA label would silently turn every weight into NA during training.
  stop("missing or NA mean logcdg for strain(s): ",
       paste(strain.names[is.na(label)], collapse = ", "), call. = FALSE)
}

# Training ----
for (l in seq_len(n.model)) {
  cluster <- grp.snps[[l]]
  group <- length(unique(cluster)) # cutree numbers clusters 1..k
  # One randomly chosen SNP (row of `genes`) from each cluster. Index through
  # sample.int(): sample(x, 1) on a length-1 `x` would sample from 1:x.
  picked <- vapply(seq_len(group), function(state) {
    members <- which(cluster == state)
    members[sample.int(length(members), 1L)]
  }, integer(1))
  features <- t(genes[picked, , drop = FALSE]) # strains x group
  rownames(features) <- strain.names
  labelled <- cbind(label, features)

  # Shuffle the strains once, then split into disjoint training/target sets
  # so target accuracy is measured on strains the model never trained on.
  perm <- sample.int(nrow(labelled))
  snp[[l]] <- labelled[perm[seq_len(training)], , drop = FALSE]
  target.snp[[l]] <- labelled[perm[training + seq_len(target)], , drop = FALSE]

  x.train <- snp[[l]][, -1, drop = FALSE]
  t.train <- snp[[l]][, 1]
  w[[l]] <- runif(group, 1e-3, 1e-2) # small random initial weights
  weights[[l]] <- matrix(0, nrow = n.epoch, ncol = group)

  for (epoch in seq_len(n.epoch)) {
    a[, l] <- x.train %*% w[[l]] # activation
    y[, l] <- 1 / (1 + exp(-a[, l] - b[l])) # sigmoid output
    e[, l] <- t.train - y[, l] # error: label - output
    # Full-batch gradient step: (label - output) is the negative gradient of
    # the summed binary cross-entropy w.r.t. the pre-activation.
    # `x.train * e[, l]` scales each row of x.train by that row's error.
    w[[l]] <- w[[l]] + eta * colSums(x.train * e[, l])
    b[l] <- b[l] + eta * sum(e[, l])
    weights[[l]][epoch, ] <- w[[l]]
    accuracy[epoch, l] <- mean(round(y[, l]) == t.train)
  }
}
# Target evaluation ----
# Apply each trained model (weights `w[[i]]`, bias `b[i]`) to its target rows.
# `target.y[[i]]` holds the sigmoid outputs. `target.acc[i]` is the fraction of
# target rows whose rounded output equals the label in column 1.
target.y <- vector("list", length(target.snp))
target.acc <- numeric(length(target.snp))
for (i in seq_along(target.snp)) {
  # Every column after the label is a feature; no width arithmetic needed.
  x.target <- target.snp[[i]][, -1, drop = FALSE]
  target.y[[i]] <- 1 / (1 + exp(-(x.target %*% w[[i]]) - b[i]))
  target.acc[i] <- mean(round(target.y[[i]]) == target.snp[[i]][, 1])
}
# Plots ----
# Number of clusters behind each model, used as the x-axis. Plotting against
# the real group counts (not the model index) keeps points aligned with xlim.
n.groups <- vapply(grp.snps, function(g) length(unique(g)), integer(1))
last.epoch <- nrow(accuracy)
train.acc <- accuracy[last.epoch, seq_along(n.groups)]

# Final training accuracy vs. held-out target accuracy per model.
plot(NA, xlim = range(n.groups), ylim = c(0, 1), ylab = "accuracy",
     xlab = "group", main = "Training vs. Target")
lines(x = n.groups, y = target.acc, col = 1, lwd = 2)
lines(x = n.groups, y = train.acc, col = 2, lwd = 2)
legend("bottom", legend = c("Training", "Target"), col = c(2, 1), lwd = 2,
       text.col = c(2, 1), horiz = TRUE)

# Final training accuracy alone.
plot(n.groups, train.acc, xlim = range(n.groups), type = "p",
     main = "Accuracy of prediction", xlab = "group", ylab = "accuracy",
     col = "blue")

# Weight trajectories over epochs; one colour per model, one line per weight.
# ylim follows the data so large weights are not clipped.
steps <- seq_len(last.epoch)
plot(NA, xlim = range(steps), ylim = range(unlist(weights), finite = TRUE),
     main = "Weight epoch throughout groups", xlab = "epoch",
     ylab = "weights range")
clr <- rainbow(length(weights))
for (i in seq_along(weights)) {
  matlines(steps, weights[[i]], col = clr[i], lty = 1)
}
