Last active
December 14, 2015 11:58
-
-
Save aldente39/5082972 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Dense matrix multiplication on `Array[Array[Double]]` matrices indexed as
 * `m(row)(col)`: two schoolbook loop orders (`ikj`, `ijk`), Strassen's
 * recursive algorithm (`strassen`), and a small timing harness (`main`).
 *
 * Helpers read the column count from row 0, so matrices are expected to be
 * rectangular (every row as long as row 0).
 */
object Strassen {

  /**
   * Pads `m` with zeros into a square matrix whose side is the smallest power
   * of two that is >= `size` (0 when `size <= 0`).
   *
   * @param m    matrix copied into the top-left corner; must fit in the padded square
   * @param size minimum side length of the result
   * @return a new zero-padded square matrix
   */
  def add0(m: Array[Array[Double]], size: Int): Array[Array[Double]] = {
    val side = nextPowerOfTwo(size)
    val res = Array.ofDim[Double](side, side)
    if (m.nonEmpty) {
      // Every row contributes m(0).length columns.
      val cols = m(0).length
      for (i <- m.indices) Array.copy(m(i), 0, res(i), 0, cols)
    }
    res
  }

  /**
   * Smallest power of two >= `n`, or 0 when `n <= 0`. Integer bit operations
   * are used because `ceil(log(n) / log(2))` may round an exact power of two
   * up to the next exponent, needlessly doubling the padded size.
   */
  private def nextPowerOfTwo(n: Int): Int =
    if (n <= 1) math.max(n, 0) else Integer.highestOneBit(n - 1) << 1

  /** Element-wise sum `a1 + a2`; the result has the shape of `a1`. */
  def adda(a1: Array[Array[Double]], a2: Array[Array[Double]]): Array[Array[Double]] =
    elementwise(a1, a2)(_ + _)

  /** Element-wise difference `a1 - a2`; the result has the shape of `a1`. */
  def suba(a1: Array[Array[Double]], a2: Array[Array[Double]]): Array[Array[Double]] =
    elementwise(a1, a2)(_ - _)

  /**
   * Applies `op` cell by cell over the shape of `a1`; `a2` must be at least
   * that large in both dimensions. An empty `a1` yields an empty result.
   */
  private def elementwise(a1: Array[Array[Double]], a2: Array[Array[Double]])(
      op: (Double, Double) => Double): Array[Array[Double]] = {
    val rows = a1.length
    val cols = if (rows == 0) 0 else a1(0).length
    val res = Array.ofDim[Double](rows, cols)
    for (i <- 0 until rows) {
      val x = a1(i)
      val y = a2(i)
      val out = res(i)
      for (j <- 0 until cols) out(j) = op(x(j), y(j))
    }
    res
  }

  /**
   * Schoolbook product `a1 * a2` with i-k-j loop order: the innermost loop
   * walks a row of `a2` and a row of the result.
   *
   * @return an `a1.length x a2(0).length` matrix (0 columns when `a2` is empty)
   */
  def ikj(a1: Array[Array[Double]], a2: Array[Array[Double]]): Array[Array[Double]] = {
    val rows = a1.length
    val inner = a2.length
    val cols = if (inner == 0) 0 else a2(0).length
    val res = Array.ofDim[Double](rows, cols)
    for (i <- 0 until rows) {
      val aRow = a1(i)
      val out = res(i)
      for (k <- 0 until inner) {
        val aik = aRow(k)
        val bRow = a2(k)
        for (j <- 0 until cols) out(j) += aik * bRow(j)
      }
    }
    res
  }

  /**
   * Schoolbook product `a1 * a2` with i-j-k loop order: each result cell is
   * accumulated by walking a row of `a1` and a column of `a2`.
   *
   * @return an `a1.length x a2(0).length` matrix (0 columns when `a2` is empty)
   */
  def ijk(a1: Array[Array[Double]], a2: Array[Array[Double]]): Array[Array[Double]] = {
    val rows = a1.length
    val inner = a2.length
    val cols = if (inner == 0) 0 else a2(0).length
    val res = Array.ofDim[Double](rows, cols)
    for (i <- 0 until rows) {
      val aRow = a1(i)
      val out = res(i)
      for (j <- 0 until cols; k <- 0 until inner) out(j) += aRow(k) * a2(k)(j)
    }
    res
  }

  /**
   * One recursive Strassen step on square matrices whose side is a power of
   * two (as produced by `add0`).
   *
   * @param leaf matrices with side <= `leaf` (or side <= 1) are multiplied with `ikj`
   * @return `a * b`, with the same side as `a`
   */
  def strassen_r(a: Array[Array[Double]], b: Array[Array[Double]], leaf: Int): Array[Array[Double]] = {
    val n = a.length
    // Also stop at 1x1: with leaf <= 0 a 1x1 input would otherwise split into empty halves.
    if (n <= leaf || n <= 1) ikj(a, b)
    else {
      val h = n / 2
      val a11 = quadrant(a, 0, 0, h)
      val a12 = quadrant(a, 0, h, h)
      val a21 = quadrant(a, h, 0, h)
      val a22 = quadrant(a, h, h, h)
      val b11 = quadrant(b, 0, 0, h)
      val b12 = quadrant(b, 0, h, h)
      val b21 = quadrant(b, h, 0, h)
      val b22 = quadrant(b, h, h, h)

      val p1 = strassen_r(adda(a11, a22), adda(b11, b22), leaf)
      val p2 = strassen_r(adda(a21, a22), b11, leaf)
      val p3 = strassen_r(a11, suba(b12, b22), leaf)
      val p4 = strassen_r(a22, suba(b21, b11), leaf)
      val p5 = strassen_r(adda(a11, a12), b22, leaf)
      val p6 = strassen_r(suba(a21, a11), adda(b11, b12), leaf)
      val p7 = strassen_r(suba(a12, a22), adda(b21, b22), leaf)

      // Strassen recombination: C11 = P1 + P4 - P5 + P7, C22 = P1 - P2 + P3 + P6.
      // Note that P7 and P6 are added, not subtracted.
      val c11 = adda(suba(adda(p1, p4), p5), p7)
      val c12 = adda(p3, p5)
      val c21 = adda(p2, p4)
      val c22 = adda(adda(suba(p1, p2), p3), p6)

      val res = Array.ofDim[Double](n, n)
      for (i <- 0 until h) {
        Array.copy(c11(i), 0, res(i), 0, h)
        Array.copy(c12(i), 0, res(i), h, h)
        Array.copy(c21(i), 0, res(i + h), 0, h)
        Array.copy(c22(i), 0, res(i + h), h, h)
      }
      res
    }
  }

  /** Copies the `size x size` sub-block of `m` whose top-left cell is (`row0`, `col0`). */
  private def quadrant(m: Array[Array[Double]], row0: Int, col0: Int, size: Int): Array[Array[Double]] = {
    val res = Array.ofDim[Double](size, size)
    for (i <- 0 until size) Array.copy(m(row0 + i), col0, res(i), 0, size)
    res
  }

  /**
   * Multiplies `m1` (r x k) by `m2` (k x c) using Strassen's algorithm.
   *
   * Both operands are zero-padded to a common power-of-two square, multiplied
   * with `strassen_r`, and the padding is trimmed from the product.
   *
   * @param leaf side length at or below which the recursion switches to `ikj`
   * @return the r x c product
   * @throws IllegalArgumentException if `m1`'s column count differs from `m2`'s row count
   */
  def strassen(m1: Array[Array[Double]], m2: Array[Array[Double]], leaf: Int = 128): Array[Array[Double]] = {
    val rows = m1.length
    val inner = m1.headOption.fold(0)(_.length)
    val cols = m2.headOption.fold(0)(_.length)
    require(inner == m2.length, s"cannot multiply ${rows}x$inner by ${m2.length}x$cols matrices")
    if (rows == 0 || inner == 0 || cols == 0) Array.ofDim[Double](rows, cols)
    else {
      val side = List(rows, inner, cols).max
      val padded = strassen_r(add0(m1, side), add0(m2, side), leaf)
      // Rows/columns introduced by padding are all zero; drop them.
      if (padded.length == rows && rows == cols) padded
      else padded.take(rows).map(_.take(cols))
    }
  }

  /**
   * Evaluates `proc` once and prints the elapsed time as "<ms>msec.".
   * `System.nanoTime` is used because it is intended for elapsed-time
   * measurement, unlike the wall-clock `currentTimeMillis`.
   */
  def time(proc: => Unit): Unit = {
    val start = System.nanoTime
    proc
    val elapsedMs = (System.nanoTime - start) / 1000000L
    println(s"${elapsedMs}msec.")
  }

  /** Times `strassen` (leaf 256) against `ikj` on a 1024 x 1024 zero matrix. */
  def main(args: Array[String]): Unit = {
    val a = Array.ofDim[Double](1024, 1024)
    time(strassen(a, a, 256))
    time(ikj(a, a))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment