Site icon R-bloggers

Fast matrix inversion

[This article was first published on Statistic on aiR, and kindly contributed to R-bloggers]. (You can report issue about the content on this page here)
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.
In much the same way as we built a function for fast multiplication of large matrices using the Strassen algorithm (see previous post), we now write functions to quickly calculate the inverse of a matrix.

To avoid rewriting pages and pages of comments and formulas, as I did for matrix multiplication, this time I’ll show you the code of the functions directly (the reasoning behind it is quite similar). Please copy and paste all the code into an external editor to see it properly.

Function strassenInv(A)


# Invert a square 2^k x 2^k matrix with one level of block inversion.
#
# A is split into four equal blocks; the two half-size inverses are computed
# with solve() and every block product with %*%.
#
# Args:
#   A: square matrix whose dimension is a power of two. solve() is applied to
#      the leading block A11 and to R5 = A21 %*% A11^-1 %*% A12 - A22, so both
#      of these must be invertible.
#
# Returns:
#   The inverse of A, a matrix of the same dimension as A.
#
# Raises:
#   An error if A is not a square matrix, or if its dimension is not a power
#   of two.
strassenInv <- function(A) {

  # Split the r x r matrix M into four (r/2) x (r/2) blocks.
  # drop = FALSE keeps 1 x 1 blocks as matrices instead of bare scalars.
  div4 <- function(M, r) {
    top <- seq_len(r / 2)
    bottom <- top + r / 2
    list(
      X11 = M[top, top, drop = FALSE],
      X12 = M[top, bottom, drop = FALSE],
      X21 = M[bottom, top, drop = FALSE],
      X22 = M[bottom, bottom, drop = FALSE]
    )
  }

  # Inputs without two dimensions (e.g. plain vectors) are rejected here
  # instead of failing later with an unrelated error.
  size <- dim(A)
  if (length(size) != 2L || size[1L] != size[2L]) {
    stop("only square matrices can be inverted")
  }
  n <- size[1L]

  # n is a power of two exactly when it is positive and has a single bit set.
  if (n < 1L || bitwAnd(n, n - 1L) != 0L) {
    stop("only square matrices of dimension 2^k * 2^k can be inverted with Strassen method")
  }

  # A 1 x 1 matrix cannot be split into blocks; invert it directly.
  if (n == 1L) {
    return(solve(A))
  }

  blk <- div4(A, n)

  R1 <- solve(blk$X11)         # A11^-1
  R2 <- blk$X21 %*% R1         # A21 A11^-1
  R3 <- R1 %*% blk$X12         # A11^-1 A12
  R4 <- blk$X21 %*% R3         # A21 A11^-1 A12
  R5 <- R4 - blk$X22           # negated Schur complement of A11
  R6 <- solve(R5)
  C12 <- R3 %*% R6
  C21 <- R6 %*% R2
  C11 <- R1 - R3 %*% C21
  C22 <- -R6

  rbind(cbind(C11, C12), cbind(C21, C22))
}



Function strassenInv2(A)


# Invert a square 2^k x 2^k matrix with two levels of block inversion.
#
# A is split into four equal blocks. The two half-size inverses are delegated
# to strassenInv() and every block product to a one-level Strassen
# multiplication (seven %*% products instead of eight).
#
# Args:
#   A: square matrix whose dimension is a power of two. The leading block A11
#      and R5 = A21 %*% A11^-1 %*% A12 - A22 are both inverted, so both must
#      be invertible.
#
# Returns:
#   The inverse of A, a matrix of the same dimension as A. Matrices smaller
#   than 4 x 4 cannot be split two levels deep and are inverted with solve().
#
# Raises:
#   An error if A is not a square matrix, or if its dimension is not a power
#   of two.
strassenInv2 <- function(A) {

  # Split the r x r matrix M into four (r/2) x (r/2) blocks.
  # drop = FALSE keeps 1 x 1 blocks as matrices instead of bare scalars.
  div4 <- function(M, r) {
    top <- seq_len(r / 2)
    bottom <- top + r / 2
    list(
      X11 = M[top, top, drop = FALSE],
      X12 = M[top, bottom, drop = FALSE],
      X21 = M[bottom, top, drop = FALSE],
      X22 = M[bottom, bottom, drop = FALSE]
    )
  }

  # One-level Strassen product of two square matrices of equal, even size:
  # seven block products (computed with %*%) are combined into X %*% Y.
  strassen <- function(X, Y) {
    X <- div4(X, dim(X)[1L])
    Y <- div4(Y, dim(Y)[1L])
    M1 <- (X$X11 + X$X22) %*% (Y$X11 + Y$X22)
    M2 <- (X$X21 + X$X22) %*% Y$X11
    M3 <- X$X11 %*% (Y$X12 - Y$X22)
    M4 <- X$X22 %*% (Y$X21 - Y$X11)
    M5 <- (X$X11 + X$X12) %*% Y$X22
    M6 <- (X$X21 - X$X11) %*% (Y$X11 + Y$X12)
    M7 <- (X$X12 - X$X22) %*% (Y$X21 + Y$X22)

    rbind(
      cbind(M1 + M4 - M5 + M7, M3 + M5),
      cbind(M2 + M4, M1 - M2 + M3 + M6)
    )
  }

  # Inputs without two dimensions (e.g. plain vectors) are rejected here
  # instead of failing later with an unrelated error.
  size <- dim(A)
  if (length(size) != 2L || size[1L] != size[2L]) {
    stop("only square matrices can be inverted")
  }
  n <- size[1L]

  # n is a power of two exactly when it is positive and has a single bit set.
  if (n < 1L || bitwAnd(n, n - 1L) != 0L) {
    stop("only square matrices of dimension 2^k * 2^k can be inverted with Strassen method")
  }

  # Below 4 x 4 the half-size blocks are too small for strassen() to split
  # again, so the matrix is inverted directly.
  if (n < 4L) {
    return(solve(A))
  }

  blk <- div4(A, n)

  R1 <- strassenInv(blk$X11)       # A11^-1
  R2 <- strassen(blk$X21, R1)      # A21 A11^-1
  R3 <- strassen(R1, blk$X12)      # A11^-1 A12
  R4 <- strassen(blk$X21, R3)      # A21 A11^-1 A12
  R5 <- R4 - blk$X22               # negated Schur complement of A11
  R6 <- strassenInv(R5)
  C12 <- strassen(R3, R6)
  C21 <- strassen(R6, R2)
  C11 <- R1 - strassen(R3, C21)
  C22 <- -R6

  rbind(cbind(C11, C12), cbind(C21, C22))
}



Function strassenInv3(A)


# Invert a square 2^k x 2^k matrix with three levels of block inversion.
#
# A is split into four equal blocks. The two half-size inverses are delegated
# to strassenInv2() and every block product to a two-level Strassen
# multiplication.
#
# Args:
#   A: square matrix whose dimension is a power of two. The leading block A11
#      and R5 = A21 %*% A11^-1 %*% A12 - A22 are both inverted, so both must
#      be invertible.
#
# Returns:
#   The inverse of A, a matrix of the same dimension as A. Matrices smaller
#   than 8 x 8 cannot be split three levels deep and are inverted with
#   solve().
#
# Raises:
#   An error if A is not a square matrix, or if its dimension is not a power
#   of two.
strassenInv3 <- function(A) {

  # Split the r x r matrix M into four (r/2) x (r/2) blocks.
  # drop = FALSE keeps 1 x 1 blocks as matrices instead of bare scalars.
  div4 <- function(M, r) {
    top <- seq_len(r / 2)
    bottom <- top + r / 2
    list(
      X11 = M[top, top, drop = FALSE],
      X12 = M[top, bottom, drop = FALSE],
      X21 = M[bottom, top, drop = FALSE],
      X22 = M[bottom, bottom, drop = FALSE]
    )
  }

  # Combine the seven Strassen products of the blocks of X and Y into X %*% Y.
  # `mult` is the function used for each block product.
  strassen_with <- function(X, Y, mult) {
    X <- div4(X, dim(X)[1L])
    Y <- div4(Y, dim(Y)[1L])
    M1 <- mult(X$X11 + X$X22, Y$X11 + Y$X22)
    M2 <- mult(X$X21 + X$X22, Y$X11)
    M3 <- mult(X$X11, Y$X12 - Y$X22)
    M4 <- mult(X$X22, Y$X21 - Y$X11)
    M5 <- mult(X$X11 + X$X12, Y$X22)
    M6 <- mult(X$X21 - X$X11, Y$X11 + Y$X12)
    M7 <- mult(X$X12 - X$X22, Y$X21 + Y$X22)

    rbind(
      cbind(M1 + M4 - M5 + M7, M3 + M5),
      cbind(M2 + M4, M1 - M2 + M3 + M6)
    )
  }

  # One-level Strassen product: block products are computed with %*%.
  strassen <- function(X, Y) {
    strassen_with(X, Y, function(P, Q) P %*% Q)
  }

  # Two-level Strassen product: block products are computed with strassen().
  strassen2 <- function(X, Y) {
    strassen_with(X, Y, strassen)
  }

  # Inputs without two dimensions (e.g. plain vectors) are rejected here
  # instead of failing later with an unrelated error.
  size <- dim(A)
  if (length(size) != 2L || size[1L] != size[2L]) {
    stop("only square matrices can be inverted")
  }
  n <- size[1L]

  # n is a power of two exactly when it is positive and has a single bit set.
  if (n < 1L || bitwAnd(n, n - 1L) != 0L) {
    stop("only square matrices of dimension 2^k * 2^k can be inverted with Strassen method")
  }

  # Below 8 x 8 the half-size blocks are too small for strassen2() to split
  # twice more, so the matrix is inverted directly.
  if (n < 8L) {
    return(solve(A))
  }

  blk <- div4(A, n)

  R1 <- strassenInv2(blk$X11)       # A11^-1
  R2 <- strassen2(blk$X21, R1)      # A21 A11^-1
  R3 <- strassen2(R1, blk$X12)      # A11^-1 A12
  R4 <- strassen2(blk$X21, R3)      # A21 A11^-1 A12
  R5 <- R4 - blk$X22                # negated Schur complement of A11
  R6 <- strassenInv2(R5)
  C12 <- strassen2(R3, R6)
  C21 <- strassen2(R6, R2)
  C11 <- R1 - strassen2(R3, C21)
  C22 <- -R6

  rbind(cbind(C11, C12), cbind(C21, C22))
}



We now run some tests. First we check that the functions successfully invert the matrix, comparing their results with those of the standard R function solve():

A <- matrix(trunc(rnorm(512 * 512) * 100), 512, 512)

# Reference inverse from base R, computed once and reused for every check.
reference <- solve(A)

# Compare with an absolute tolerance instead of testing rounded floating-point
# values for exact equality: the largest entry-wise deviation from the
# reference must stay below the stated precision.
max(abs(reference - strassenInv(A))) < 1e-8
# [1] TRUE

max(abs(reference - strassenInv2(A))) < 1e-8
# [1] TRUE

max(abs(reference - strassenInv3(A))) < 1e-6
# [1] TRUE


The functions perform the operations correctly. But there is an approximation problem: the first two functions are accurate to the eighth decimal place, while the third is accurate only to the sixth. This is probably not an error in the calculation itself, but a consequence of how numbers are represented in binary (32-bit) format, which causes these rounding errors.

Now we analyze the computation time. The table shows the results, obtained by running the following code:

Time computation


# Time solve() against the three block-inversion functions on random,
# integer-valued square matrices of increasing size. The timings are printed
# explicitly because results are not auto-printed inside a loop.
for (size in c(512, 1024, 2048, 4096)) {
  A <- matrix(trunc(rnorm(size * size) * 100), size, size)
  print(system.time(solve(A)))
  print(system.time(strassenInv(A)))
  print(system.time(strassenInv2(A)))
  print(system.time(strassenInv3(A)))
}





The results are quite clear: using a modification of the Strassen algorithm for matrix inversion yields a real time saving.

Please, remember these two recommendations already made:
- The code can still be improved; if anyone wants to help me, I will be happy to update it
- If you find these functions useful for any work, a citation is always welcome (contact me by e-mail for details)

To leave a comment for the author, please follow the link and comment on their blog: Statistic on aiR.

R-bloggers.com offers daily e-mail updates about R news and tutorials about learning R and many other topics. Click here if you're looking to post or find an R/data-science job.
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.