Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.
The following example was shown at an advanced statistics seminar held in Tel Aviv. The material for the presentation comes from C. M. Bishop's book Pattern Recognition and Machine Learning (Springer, 2006).
One way to separate two classes with linear subspaces of the input space (e.g. planes for 3D inputs, lines for 2D inputs) is dimensionality reduction:
$$y = w^{\top} x$$
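If x belongs to a P-dimensional space of real numbers and w is a P×1 vector of weights, then y is a scalar: up to the scale factor |w|, it is the coordinate of x projected onto the direction of w, so the P input dimensions are reduced to one. We assign x to C1 if y > 0 and to C2 otherwise. The set of points x with y = 0 is called the “decision boundary”; it is a (P-1)-dimensional hyperplane through the origin.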
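A minimal sketch of this decision rule in R. The weight vector w and the three input points are hypothetical values chosen only for illustration, not taken from the seminar: each row of x is projected with y = w'x and labelled by the sign of y.

```r
# Hypothetical weight vector and input points, for illustration only
w <- c(1, -2)                        # P = 2, so w is a 2 x 1 weight vector
x <- rbind(c(3, 1),                  # each row is one input point
           c(0, 2),
           c(4, 2))
y <- as.numeric(x %*% w)             # y = w'x for every row
label <- ifelse(y > 0, "C1", "C2")   # C1 if y > 0, C2 otherwise
signed.dist <- y / sqrt(sum(w^2))    # signed distance of each point from the boundary y = 0
data.frame(x1 = x[, 1], x2 = x[, 2], y = y, class = label, dist = signed.dist)
```

The third point gives y = 0, so it lies exactly on the decision boundary; the rule as stated assigns it to C2.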
The book lists further properties of this method; the one that matters for this presentation is that w is orthogonal to every vector lying within the decision boundary. Proof: let x1 and x2 be two points on the decision boundary separating the regions R1 and R2. Since y = 0 at both points, we get:
$$w^{\top} x_1 = w^{\top} x_2 = 0$$
$$w^{\top} (x_1 - x_2) = 0$$
Since x1 - x2 can be any vector lying within the decision boundary, w is orthogonal to the decision boundary.
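A quick numerical check of the orthogonality argument, again with a hypothetical w. Any multiple of v = (-w2, w1) satisfies w'v = 0, so two such multiples are two points on the boundary, and their difference should be orthogonal to w.

```r
w  <- c(1, -2)              # hypothetical 2 x 1 weight vector
v  <- c(-w[2], w[1])        # v = (2, 1), a direction with w'v = 0
x1 <- 3 * v                 # (6, 3): a point on the decision boundary
x2 <- -1.5 * v              # (-3, -1.5): another point on it
y1 <- sum(w * x1)           # w'x1, zero because x1 is on the boundary
y2 <- sum(w * x2)           # w'x2, also zero
dot <- sum(w * (x1 - x2))   # w'(x1 - x2), zero: w is orthogonal to the boundary
c(y1 = y1, y2 = y2, dot = dot)
```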
Now assume we have two groups that follow bivariate normal distributions with different means and a common variance-covariance matrix Σ.
Given the group means M1 and M2, a natural strategy is to choose w so that the projected means m1 = w'M1 and m2 = w'M2 are as far apart as possible. Adding the constraint w'w = 1 and using a Lagrange multiplier, we find that w is proportional to (M1 - M2). The problem begins when the groups have strongly non-diagonal covariances: the projected values can then overlap considerably, even when the two means are well separated. Fisher (1936) proposed a criterion for this case, maximizing the ratio of the separation between the projected means to the within-group variance of the projected values, which gives w proportional to inverse(Σ)(M1 - M2).
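To see why the simple solution can lose to Fisher's, one can evaluate Fisher's ratio J(w) = (w'(M2 - M1))^2 / (w'Σw), the squared distance between the projected means divided by the common projected variance. The sketch below assumes the population values used in the sampling code further down (means (2, -1) and (2, 5), rho = 0.8, standard deviations 1 and 3) rather than sample estimates; J and the angle scan are illustrative helpers, not part of the original code.

```r
mu.1 <- c(2, -1); mu.2 <- c(2, 5)
rho <- 0.8; sigma.1 <- 1; sigma.2 <- 3
Sigma <- matrix(c(sigma.1^2, rho * sigma.1 * sigma.2,
                  rho * sigma.1 * sigma.2, sigma.2^2), nrow = 2, byrow = TRUE)
d <- mu.2 - mu.1
# Fisher's ratio: squared separation of the projected means over the projected variance
J <- function(w) sum(w * d)^2 / as.numeric(t(w) %*% Sigma %*% w)
w.simple <- d / sqrt(sum(d^2))                 # maximises mean separation subject to w'w = 1
w.fisher <- solve(Sigma, d)                    # proportional to inverse(Sigma)(M2 - M1)
w.fisher <- w.fisher / sqrt(sum(w.fisher^2))   # rescaled to unit length; J ignores the scale
# brute-force check: J over unit vectors at 1801 angles in [0, pi]
theta <- seq(0, pi, length.out = 1801)
J.grid <- sapply(theta, function(a) J(c(cos(a), sin(a))))
c(simple = J(w.simple), fisher = J(w.fisher), grid.max = max(J.grid))
```

With these values the unit mean-difference direction gives J = 4, while inverse(Σ)(M2 - M1) gives J = 36/3.24 ≈ 11.1, and no direction on the scanned grid does better.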
The projected values are linear combinations of bivariate normal variables, so each group's projection is itself normally distributed, and comparing the two histograms shows how much the groups overlap. The code is attached below: follow it step by step, and adjust the device size and sampling parameters to see how Fisher's criterion for w separates the groups better than the simple optimization.
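Because each projection is normal, its parameters follow directly: group k projects to a normal distribution with mean w'μk and variance w'Σw. A sketch, assuming the same population parameters as the sampling code, equal group sizes, and a threshold halfway between the projected means (proj() is a hypothetical helper): the fraction of each group falling on the wrong side of that threshold is pnorm(-|m2 - m1| / (2 sd)).

```r
mu.1 <- c(2, -1); mu.2 <- c(2, 5)
Sigma <- matrix(c(1, 2.4, 2.4, 9), nrow = 2)   # rho = 0.8, sigma.1 = 1, sigma.2 = 3
proj <- function(w) {
  m1 <- sum(w * mu.1)                           # projected mean of group 1, w'mu.1
  m2 <- sum(w * mu.2)                           # projected mean of group 2, w'mu.2
  s  <- sqrt(as.numeric(t(w) %*% Sigma %*% w))  # common projected sd, sqrt(w' Sigma w)
  # threshold at the midpoint of the projected means: each group's tail beyond it is misclassified
  err <- pnorm(-abs(m2 - m1) / (2 * s))
  c(mean.1 = m1, mean.2 = m2, sd = s, error = err)
}
w.simple <- c(0, 1)                             # unit vector along mu.2 - mu.1
w.fisher <- solve(Sigma, mu.2 - mu.1)           # Fisher's direction (scale does not matter)
round(rbind(simple = proj(w.simple), fisher = proj(w.fisher)), 4)
```

Under these assumptions the simple direction misclassifies about 16% of each group, while Fisher's direction misclassifies about 5%, which is the gap the histograms in the code below make visible.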
# (1) sampling assuming mvnorm with same Sigma
{
  rm(list = ls())
  mu.1 <- c(2, -1)
  mu.2 <- c(2, 5)
  rho <- 0.8
  sigma.1 <- 1
  sigma.2 <- 3
  Sigma <- matrix(c(sigma.1^2, rho * sigma.1 * sigma.2,
                    rho * sigma.1 * sigma.2, sigma.2^2), byrow = TRUE, nrow = 2)
  N <- 100
  # multivariate normal sampling (requires the MASS package)
  X1 <- MASS::mvrnorm(N, mu = mu.1, Sigma = Sigma)
  X2 <- MASS::mvrnorm(N, mu = mu.2, Sigma = Sigma)
  # make a data frame; the group codes double as plot colours (4 = blue, 2 = red)
  X <- data.frame(cbind(rep(c(4, 2), each = N), rbind(X1, X2)))
  names(X) <- c("group", "X1", "X2")
  # group means; tapply sorts the groups, so row 1 is group 2 and row 2 is group 4
  means <- matrix(c(tapply(X$X1, X$group, mean), tapply(X$X2, X$group, mean)), 2, 2)
  means <- data.frame(X1 = means[, 1], X2 = means[, 2], row.names = c(2, 4))
  # matrix of the two inputs, one row per observation
  A <- as.matrix(X[, c("X1", "X2")])
}
# (2) plot the sample
{
  PLOT <- function(main) {
    # title panel on top, scatter plot bottom-left, histogram bottom-right
    layout(matrix(c(1, 1, 2, 3), byrow = TRUE, ncol = 2), heights = c(0.2, 0.8))
    par(mar = rep(0, 4))
    plot.new()
    text(0.5, 0.5, main, cex = 3)
    par(mar = c(4, 4, 0, 0))
    plot(X2 ~ X1, data = X, type = "n",
         xlab = expression(x[1]), ylab = expression(x[2]))
    # light grid at the axis tick positions
    axx <- par("xaxp")
    axx <- seq(axx[1], axx[2], length.out = axx[3] + 1)
    axy <- par("yaxp")
    axy <- seq(axy[1], axy[2], length.out = axy[3] + 1)
    abline(h = axy, v = axx, lty = 5, col = gray(0.8))
    points(X2 ~ X1, pch = 21, bg = group, col = group, data = X)
  }
  PLOT(main = "Simple Linear Classification")
}
# (3) show means and mean line
{
  points(means, col = 1, pch = 3, cex = 3, lwd = 3)
  lines(means, lwd = 3, col = 1)
}
# (4) calculate midpoint and orthogonal line
{
  mid.point <- colMeans(means)
  points(mid.point[1], mid.point[2], col = 1, pch = 4, cex = 2, lwd = 4)
  H <- c(-1, 1)
  m <- H %*% as.matrix(means)   # difference of the means (group 4 minus group 2)
  m <- m[2] / m[1]              # slope of the line joining the means
  inv.m <- -1 / m               # slope of the orthogonal line (the simple decision boundary)
  arrows(x0 = mid.point[1], x1 = mid.point[1] + 2,
         y0 = mid.point[2], y1 = mid.point[2] + 2 * inv.m,
         col = 1, lwd = 3)
}
# (5) Simple Linear Optimization
{
  # overlaid histograms of projected values, using one common set of breaks for both groups
  HIST <- function(values) {
    breaks <- pretty(range(values), n = 20)
    h4 <- hist(values[X$group == 4], breaks = breaks, plot = FALSE)
    h2 <- hist(values[X$group == 2], breaks = breaks, plot = FALSE)
    plot(h2, border = 0, main = "", xlab = "Projection Values",
         ylim = c(0, max(h4$counts, h2$counts)))
    rect( # blue group
      xleft = breaks[-length(breaks)], ybottom = 0, xright = breaks[-1],
      ytop = h4$counts, col = 4, density = 20, angle = 45)
    rect( # red group
      xleft = breaks[-length(breaks)], ybottom = 0, xright = breaks[-1],
      ytop = h2$counts, col = 2, density = 20, angle = -45)
  }
  SLO <- function() {
    d <- unlist(means[2, ] - means[1, ])  # difference of the means (group 4 minus group 2)
    lambda <- sqrt(sum(d^2)) / 2          # Lagrange multiplier
    w <- d / (2 * lambda)                 # unit-length w, proportional to the mean difference
    HIST(as.numeric(A %*% w))             # project every observation onto w
  } # close on SLO
  SLO()
}
# (6) Fisher's Criterion
{
  means <- matrix(c(tapply(X$X1, X$group, mean), tapply(X$X2, X$group, mean)), 2, 2)
  d <- means[1, ] - means[2, ]      # difference of the means (group 2 minus group 4)
  S.W <- (cov(X1) + cov(X2)) / 2    # pooled within-group estimate of Sigma (equal group sizes)
  b <- as.numeric(solve(S.W, d))    # Fisher direction, proportional to inverse(Sigma)(M1 - M2)
  b0 <- -sum(b * colMeans(means))   # intercept placing the boundary at the midpoint of the means
  Lx <- as.numeric(A %*% b + b0)    # discriminant score of every observation
}
# (7) re-plot the sample
{
  PLOT(main = "Fisher's Criterion")
}
# (8) plot Fisher's Criterion: the line b0 + b[1]*x1 + b[2]*x2 = 0
{
  abline(a = -b0 / b[2], b = -b[1] / b[2], lwd = 3)
}
# (9) plot Projection Histogram
{
  FC <- function() HIST(Lx)  # histograms of the Fisher scores, drawn as in step (5)
  FC()
}