Key reference: MacKay (2003), Information Theory, Inference, and Learning Algorithms, Algorithm 39.5 (p. 478).
data(iris)
library(nnet)
# Keep only versicolor and virginica, and three columns of iris:
# Sepal.Length (1), Petal.Length (3) and Species (5).
iris.2sp <- iris[iris[, 5] != "setosa", c(1, 3, 5)]
# To draw a random sample for cross-validation instead:
# t.iris <- iris.2sp[sample(nrow(iris.2sp), 75), ]
x <- as.matrix(iris.2sp[, 1:2]) # N = 100 flowers, 2 input features
# 0/1 target (a single indicator, not "one hot" coding):
# 1 = virginica, 0 = versicolor.
# Subsetting rows does not drop the unused "setosa" factor level, so an
# indicator matrix built from the factor levels has three columns and its
# second column marks versicolor. Compare against the level name so the coding
# cannot silently flip.
t <- as.numeric(iris.2sp[, 3] == "virginica")
# Train a single sigmoid neuron (a logistic classifier) by batch gradient
# descent and record the bias and weights at every epoch.
#
# Args:
#   training.mat: numeric matrix (or anything as.matrix() turns into one) with
#                 one row per observation and one column per input feature.
#   target.vec:   numeric 0/1 targets, one per row of training.mat.
#   max.epoch:    number of passes over the whole training set (0 is allowed
#                 and yields an empty history).
#   learn.rate:   step size applied to the gradient summed over all rows.
#
# Returns:
#   A data frame with max.epoch rows and columns epoch, biase, w1, w2, ...
#   (one w column per feature). Row i holds the parameters in effect during
#   epoch i, i.e. before that epoch's update is applied; the update computed
#   in the final epoch is therefore not part of the history.
neuron <- function(training.mat, target.vec, max.epoch, learn.rate) {
  x <- as.matrix(training.mat)
  n.feat <- ncol(x)
  stopifnot(
    is.numeric(x),
    length(target.vec) == nrow(x),
    length(max.epoch) == 1,
    max.epoch >= 0
  )

  # Small random initial values; w[1] is the bias, w[-1] the feature weights.
  w <- runif(n = n.feat + 1, min = 1e-4, max = 1e-3)

  # Preallocate the history as a matrix. Growing a data frame with rbind()
  # copies it every epoch, and rbind() discards a zero-row data frame, which
  # loses the intended column names.
  history <- matrix(
    NA_real_,
    nrow = max.epoch,
    ncol = n.feat + 2,
    dimnames = list(NULL, c("epoch", "biase", paste0("w", seq_len(n.feat))))
  )

  for (i in seq_len(max.epoch)) {
    # Batch learning: process all observations at once.
    a <- drop(x %*% w[-1]) + w[1]       # input/activation per observation
    y <- 1 / (1 + exp(-a))              # output/activity: logistic sigmoid
    e <- target.vec - y                 # per-observation error
    g <- -c(sum(e), crossprod(x, e))    # gradient w.r.t. (bias, weights)

    history[i, ] <- c(i, w)             # record parameters used this epoch
    w <- w - learn.rate * g             # gradient-descent update
  }

  as.data.frame(history)
}
# Fit the neuron, then plot diagnostics.
# The initial weights are random, so results vary between runs unless
# set.seed() is called beforehand.
out <- neuron(training.mat = x, target.vec = t, max.epoch = 5000, learn.rate = 0.1)
# Contours: take the parameters from the last recorded epoch. Columns of
# `out` are read by position (2 = bias, 3 = w1, 4 = w2).
bias <- out[nrow(out), 2]
w1 <- out[nrow(out), 3]
w2 <- out[nrow(out), 4]
plot(x[, 1], x[, 2], col = t + 1, xlab = "Sepal Length", ylab = "Petal Length", pch = 16, las = 1)
legend(5, 6.5, c("versicolor", "virginica"), pch = 16, col = c(1, 2), cex = 0.75)
# Inside curve(), `x` is the plotting variable (the horizontal axis), not the
# data matrix. Each line solves bias + w1 * x + w2 * y = a for y.
curve((-bias - w1 * x) / w2, 5, 8, add = TRUE) # a = 0 (decision boundary)
curve((1 - bias - w1 * x) / w2, 5, 8, add = TRUE, lty = 2) # a = 1
curve((-1 - bias - w1 * x) / w2, 5, 8, add = TRUE, lty = 2) # a = -1
# Evolution of the bias and weights across epochs (log-scaled epoch axis).
# Columns of `out` are read by position: 1 = epoch, 2 = bias, 3 = w1, 4 = w2.
plot(out[, 1], out[, 2], type = "l", ylim = range(out[, 2:4]), log = "x",
     xlab = "epoch", ylab = "weights & bias", las = 1)
lines(out[, 1], out[, 3], col = 2)
lines(out[, 1], out[, 4], col = 3)
# Anchor the legend by keyword rather than at a fixed data coordinate, so it
# stays inside the plot whatever range the weights happen to cover.
legend("topleft", c("bias", "w1", "w2"), col = 1:3, lty = 1, cex = 0.75)
# Predictions: round the sigmoid output to the nearest class (0 or 1).
t.pred <- round(1 / (1 + exp(-bias - w1 * x[, 1] - w2 * x[, 2])))
# Misclassification rate, as a fraction of however many observations there
# are (not a hard-coded 100).
missed <- mean(t.pred != t)
missed # fraction of flowers misclassified
# Recorded output, presumably from an earlier knitted run; the exact value
# depends on the random initial weights:
## [1] 0.12
plot(x[, 1], x[, 2], col = t + 1, xlab = "Sepal Length", ylab = "Petal Length", pch = 16, las = 1)
curve((-bias - w1 * x) / w2, 5, 8, add = TRUE) # a = 0 (decision boundary)
# Label every point with its predicted class, coloured by the prediction, so
# mismatches with the point colour (true class) stand out.
text(x[, 1], x[, 2], t.pred, pos = 4, col = t.pred + 1, cex = 0.8)