# Generate two multivariate normal classes in two dimensions
library("MASS")
library("e1071")
## Warning: package 'e1071' was built under R version 4.0.5
# matlib needs the rgl package to be installed before it will load:
# install.packages("rgl", repos="http://R-Forge.R-project.org")
# these other packages are needed as well:
# install.packages("crosstalk", dependencies = TRUE)
# install.packages("manipulateWidget", dependencies = TRUE)

library("matlib")
## Warning: package 'matlib' was built under R version 4.0.5
# matlib provides inv(), which computes the inverse of a matrix

# Scenario 1: two Gaussian classes that share the identity covariance
# matrix. The author describes the resulting linear discriminant as a
# nearest-mean classifier.
sigma <- diag(2)
m1 <- c(0, 0)
m2 <- c(5, 5)

# Class 1: 1000 draws around m1, all labelled 0
x1 <- mvrnorm(n = 1000, mu = m1, Sigma = sigma)
y1 <- numeric(1000)

# Class 2: 1000 draws around m2, all labelled 1
x2 <- mvrnorm(n = 1000, mu = m2, Sigma = sigma)
y2 <- rep(1, 1000)

# Labels as a factor: class 1 rows first, class 2 rows second
y_class <- factor(c(y1, y2))

# Axis limits for the scatter plot
xr <- c(-10, 20)
yr <- c(-5, 15)

# Class 1 in blue, class 2 in red, the two class means in black
plot(x1[, 1], x1[, 2], type = "p", col = "blue", xlim = xr, ylim = yr)
points(x2[, 1], x2[, 2], type = "p", col = "red")
points(m1[1], m1[2], type = "p", col = "black")
points(m2[1], m2[2], type = "p", col = "black")

# Build the linear discriminant
# Pooled covariance estimate: the average of the two per-class sample
# covariance matrices
c1 <- cov(x1)
c2 <- cov(x2)

c <- (c1 + c2) / 2
c
##           [,1]      [,2]
## [1,] 0.9834649 0.0247577
## [2,] 0.0247577 0.9891959
m <- m1 - m2
m
## [1] -5 -5
# Discriminant weights: the mean difference times the inverse of the
# pooled covariance
w <- t(m) %*% inv(c)
# The pooled covariance printed above is close to the identity, so w
# is close to m. The author describes this as the nearest-mean
# classifier, with w the direction onto which the class points are
# projected.
# w is a row vector, so a dot product with it needs a column vector
#   on the right-hand side; x1 therefore has to be transposed first.

# Scale w to unit Euclidean length so that projections onto it are
# signed distances along the discriminant direction. Base norm()
# defaults to the one norm (largest absolute column sum), which for a
# 1 x 2 row vector is max(abs(w)) and does not yield a unit vector,
# so the Euclidean length is computed explicitly.
wn <- w / sqrt(sum(w^2))

# Project every sample onto wn (wn is 1 x 2, t(x) is 2 x n)
p1 <- wn %*% t(x1)
p2 <- wn %*% t(x2)

# p1 and p2 are the class points projected onto the direction wn;
# plot their joint histogram
hist(c(p1, p2))

# The decision boundary passes through the midpoint of the segment
# joining m1 and m2. The offset b is the projection of that midpoint
# onto wn.
b <- wn %*% ((m1 + m2) / 2)

# With m1 = (0, 0), m2 = (5, 5) and w close to m = (-5, -5), wn is
# close to (-0.71, -0.71) and b is about -3.5

# Classification rule: compute z = wn %*% x - b.
# m = m1 - m2, so w points from class 2 towards class 1:
# z > 0 means class 1, z < 0 means class 2.
p1 <- p1 - as.numeric(b)
p2 <- p2 - as.numeric(b)

# Same histogram after subtracting b: the boundary now sits at 0
hist(c(p1, p2))

# Decision values of all samples, class 1 first and class 2 second
p <- c(p1, p2)
# Threshold sweep used to illustrate the confusion table and the ROC
# curve. A sample is predicted as label 1 when its decision value in p
# is below the threshold, otherwise as label 0.
# Label 1 is treated as the positive class for TPR and FPR.

# Total number of samples
N <- length(y_class)

# Candidate thresholds and one TPR/FPR slot per threshold
Tr <- seq(from = -20, to = 20, by = 1)
TPR <- rep(0, length(Tr))
FPR <- rep(0, length(Tr))
for (i in seq_along(Tr)) {
  # Named thr rather than T so the TRUE shorthand is not overwritten
  thr <- Tr[i]
  r <- 1 * (p < thr)
  # Fix both levels so the table stays 2 x 2 even when one label is
  # never predicted
  rf <- factor(r, levels = c(0, 1))
  # Confusion table: rows are true labels, columns are predicted labels
  conf_tab <- table(y_class, rf)
  print(conf_tab)

  # True positives and true negatives, i.e. the correct
  # classifications, are on the diagonal of the confusion table
  correctc <- sum(diag(conf_tab))
  print(correctc)

  # Misclassifications are the off-diagonal cells
  falsec <- conf_tab[1, 2] + conf_tab[2, 1]
  print(falsec)

  # A good threshold maximizes correctc and minimizes falsec.
  # For the ROC curve the rates are taken per true class:
  #   TPR = TP / (TP + FN), computed over the true-label-1 row
  #   FPR = FP / (FP + TN), computed over the true-label-0 row
  tpr <- conf_tab[2, 2] / sum(conf_tab[2, ])
  fpr <- conf_tab[1, 2] / sum(conf_tab[1, ])
  TPR[i] <- tpr
  FPR[i] <- fpr
}
##        rf
## y_class    0    1
##       0 1000    0
##       1 1000    0
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0 1000    0
##       1 1000    0
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0 1000    0
##       1 1000    0
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0 1000    0
##       1 1000    0
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0 1000    0
##       1 1000    0
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0 1000    0
##       1 1000    0
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0 1000    0
##       1 1000    0
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0 1000    0
##       1 1000    0
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0 1000    0
##       1 1000    0
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0 1000    0
##       1 1000    0
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0 1000    0
##       1 1000    0
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0 1000    0
##       1  998    2
## [1] 1002
## [1] 998
##        rf
## y_class    0    1
##       0 1000    0
##       1  988   12
## [1] 1012
## [1] 988
##        rf
## y_class    0    1
##       0 1000    0
##       1  921   79
## [1] 1079
## [1] 921
##        rf
## y_class    0    1
##       0 1000    0
##       1  757  243
## [1] 1243
## [1] 757
##        rf
## y_class    0    1
##       0 1000    0
##       1  503  497
## [1] 1497
## [1] 503
##        rf
## y_class    0    1
##       0 1000    0
##       1  250  750
## [1] 1750
## [1] 250
##        rf
## y_class    0    1
##       0 1000    0
##       1   84  916
## [1] 1916
## [1] 84
##        rf
## y_class    0    1
##       0 1000    0
##       1   25  975
## [1] 1975
## [1] 25
##        rf
## y_class    0    1
##       0 1000    0
##       1    3  997
## [1] 1997
## [1] 3
##        rf
## y_class    0    1
##       0  999    1
##       1    0 1000
## [1] 1999
## [1] 1
##        rf
## y_class    0    1
##       0  997    3
##       1    0 1000
## [1] 1997
## [1] 3
##        rf
## y_class    0    1
##       0  985   15
##       1    0 1000
## [1] 1985
## [1] 15
##        rf
## y_class    0    1
##       0  919   81
##       1    0 1000
## [1] 1919
## [1] 81
##        rf
## y_class    0    1
##       0  751  249
##       1    0 1000
## [1] 1751
## [1] 249
##        rf
## y_class    0    1
##       0  497  503
##       1    0 1000
## [1] 1497
## [1] 503
##        rf
## y_class    0    1
##       0  227  773
##       1    0 1000
## [1] 1227
## [1] 773
##        rf
## y_class    0    1
##       0   74  926
##       1    0 1000
## [1] 1074
## [1] 926
##        rf
## y_class    0    1
##       0   17  983
##       1    0 1000
## [1] 1017
## [1] 983
##        rf
## y_class    0    1
##       0    0 1000
##       1    0 1000
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0    0 1000
##       1    0 1000
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0    0 1000
##       1    0 1000
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0    0 1000
##       1    0 1000
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0    0 1000
##       1    0 1000
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0    0 1000
##       1    0 1000
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0    0 1000
##       1    0 1000
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0    0 1000
##       1    0 1000
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0    0 1000
##       1    0 1000
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0    0 1000
##       1    0 1000
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0    0 1000
##       1    0 1000
## [1] 1000
## [1] 1000
##        rf
## y_class    0    1
##       0    0 1000
##       1    0 1000
## [1] 1000
## [1] 1000
# Scatter the values stored in FPR and TPR on the unit square and
# annotate each point with the threshold that produced it
plot(FPR, TPR, type = "p", xlim = c(0, 1), ylim = c(0, 1))
text(x = FPR, y = TPR, labels = as.character(Tr), adj = c(0.1, 0.3))

### LETS make data nonseparable


# Scenario 2: both classes are drawn with the same non-diagonal
# covariance matrix (variances 10 and 2, covariance 3)
sigma <- matrix(c(10, 3, 3, 2), nrow = 2, ncol = 2)
m1 <- c(0, 0)
m2 <- c(10, 10)

# Class 1: 1000 draws around m1
x1 <- mvrnorm(n = 1000, mu = m1, Sigma = sigma)

# Class 2: 1000 draws around m2
x2 <- mvrnorm(n = 1000, mu = m2, Sigma = sigma)

# Axis limits for the scatter plot
xr <- c(-10, 20)
yr <- c(-5, 15)

# Class 1 in blue, class 2 in red, the two class means in black
plot(x1[, 1], x1[, 2], type = "p", col = "blue", xlim = xr, ylim = yr)
points(x2[, 1], x2[, 2], type = "p", col = "red")
points(m1[1], m1[2], type = "p", col = "black")
points(m2[1], m2[2], type = "p", col = "black")

# Pooled covariance estimate: per-class sample covariances first,
# printed for inspection, then their average
c1 <- cov(x1)
c2 <- cov(x2)
c1
##          [,1]     [,2]
## [1,] 9.441700 2.787338
## [2,] 2.787338 1.924686
c2
##           [,1]     [,2]
## [1,] 10.001305 3.251931
## [2,]  3.251931 2.099847
c <- (c1 + c2) / 2
m <- m1 - m2
# Discriminant weights. t(m) spells out the row vector that %*% would
# otherwise form implicitly from the plain vector m.
w <- t(m) %*% inv(c)

# Stack both classes into one feature matrix, class 1 rows first
x <- rbind(x1, x2)
# Labels: +1 for the class 1 rows, -1 for the class 2 rows
y <- rep(c(1, -1), times = c(nrow(x1), nrow(x2)))

# The response is stored as a factor; the printed summary below shows
# that svm() fits a C-classification model with a linear kernel
dat <- data.frame(x, y = as.factor(y))
svmfit <- svm(y ~ ., data = dat, kernel = "linear", cost = 10, scale = FALSE)
print(svmfit)
## 
## Call:
## svm(formula = y ~ ., data = dat, kernel = "linear", cost = 10, scale = FALSE)
## 
## 
## Parameters:
##    SVM-Type:  C-classification 
##  SVM-Kernel:  linear 
##        cost:  10 
## 
## Number of Support Vectors:  3
# Draw the fitted classifier over the two features in dat
plot(svmfit, dat)