Last active
June 11, 2017 12:56
-
-
Save yukoba/af768495e8ba07cf126b51a788ff58c7 to your computer and use it in GitHub Desktop.
Tensor library. Only + is implemented. Broadcasting is supported.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package jp.yukoba.tensor | |
import java.util | |
import com.github.fommil.netlib.BLAS | |
import org.scalatest.FunSuite | |
/**
 * An n-dimensional tensor of `Float`s stored as a flat `data` array plus a
 * `shape` and per-dimension `strides` (element offsets into `data`).
 *
 * A 0-dimensional tensor (empty `shape`) is a scalar held in `data(0)`.
 * Element `(i0, ..., ik)` is read from `data(i0 * strides(0) + ... + ik * strides(k))`;
 * a stride of 0 reuses the same element for every index of that dimension,
 * which is how broadcasting is represented.
 *
 * @param data    backing storage; must be non-empty
 * @param shape   length of each dimension
 * @param strides element offset between consecutive indices of each dimension
 */
class TensorFloat(val data: Array[Float], val shape: Array[Int], val strides: Array[Int]) {
  assert(data.length > 0)

  /** Number of dimensions (0 for a scalar). */
  def ndim: Int = shape.length

  /** Number of logical elements: the product of `shape` (1 for a scalar). */
  def size: Int = shape.product

  /**
   * Renders the tensor as nested parenthesised lists, e.g.
   * `TensorFloat((1.0, 2.0), (3.0, 4.0))` for a 2x2 matrix or `TensorFloat(5.0)` for a scalar.
   */
  override def toString: String = {
    // Renders the sub-tensor starting at `offset`, covering dimension `dim` and all deeper ones.
    def render(dim: Int, offset: Int): String = {
      val indices = 0 until shape(dim)
      if (dim == ndim - 1)
        indices.map(i => data(offset + strides(dim) * i)).mkString(", ")
      else
        indices.map(i => "(" + render(dim + 1, offset + strides(dim) * i) + ")").mkString(", ")
    }
    val body = if (ndim == 0) data(0).toString else render(0, 0)
    s"TensorFloat($body)"
  }

  /**
   * Computes the broadcast result shape of two operands (NumPy-style rules).
   *
   * The shorter shape is left-padded with 1s. Each dimension must then either match
   * or be 1 on one side. Along a size-1 (or padded) dimension, that operand's stride
   * is set to 0 so that its single element is reused for every index of the result.
   *
   * @return `(resultShape, stridesForOperand1, stridesForOperand2)`, all of length
   *         `resultShape.length`. When the shapes are already equal, the input arrays
   *         themselves are returned.
   * @throws IllegalArgumentException if the shapes cannot be broadcast together
   */
  def broadcast(shape1: Array[Int], shape2: Array[Int],
                strides1: Array[Int], strides2: Array[Int]): (Array[Int], Array[Int], Array[Int]) = {
    if (shape1.sameElements(shape2)) {
      (shape1, strides1, strides2)
    } else if (shape1.length > shape2.length) {
      // Normalise so that operand 1 has the fewer dimensions, then swap the strides back.
      val (bShape, bStridesOf2, bStridesOf1) = broadcast(shape2, shape1, strides2, strides1)
      (bShape, bStridesOf1, bStridesOf2)
    } else {
      val pad = shape2.length - shape1.length
      // Missing leading dimensions behave like size-1 dimensions with stride 0.
      val outShape = Array.fill(pad)(1) ++ shape1
      val outStrides1 = Array.fill(pad)(0) ++ strides1
      val outStrides2 = strides2.clone()
      for (i <- outShape.indices if outShape(i) != shape2(i)) {
        if (outShape(i) == 1) {
          outShape(i) = shape2(i)
          outStrides1(i) = 0
        } else {
          // `require` (unlike `assert`) is never elided, so bad shapes can't slip through silently.
          require(shape2(i) == 1,
            s"shapes ${shape1.mkString("(", ", ", ")")} and ${shape2.mkString("(", ", ", ")")} " +
              "cannot be broadcast together")
          outStrides2(i) = 0
        }
      }
      (outShape, outStrides1, outStrides2)
    }
  }

  /**
   * Element-wise sum with broadcasting (see [[broadcast]]).
   *
   * The result is a freshly allocated, contiguous row-major tensor; neither operand is modified.
   *
   * @throws IllegalArgumentException if the shapes cannot be broadcast together
   */
  def +(that: TensorFloat): TensorFloat = {
    val data1 = data
    val data2 = that.data
    val (outShape, strides1, strides2) = broadcast(shape, that.shape, strides, that.strides)
    if (outShape.isEmpty) {
      new TensorFloat(Array(data1(0) + data2(0)), Array.empty[Int], Array.empty[Int])
    } else {
      val out = new Array[Float](outShape.product)
      var outIdx = 0
      val last = outShape.length - 1
      // Visits the result in row-major order. The innermost dimension is a tight strided loop:
      // a plain JVM loop rather than BLAS scopy/saxpy, as the original author's note recommended.
      def loop(dim: Int, idx1: Int, idx2: Int): Unit = {
        if (dim == last) {
          val n = outShape(dim)
          val step1 = strides1(dim)
          val step2 = strides2(dim)
          var i = 0
          while (i < n) {
            out(outIdx + i) = data1(idx1 + i * step1) + data2(idx2 + i * step2)
            i += 1
          }
          outIdx += n
        } else {
          for (i <- 0 until outShape(dim))
            loop(dim + 1, idx1 + i * strides1(dim), idx2 + i * strides2(dim))
        }
      }
      loop(0, 0, 0)
      // Clone: when the shapes are equal, broadcast returns this.shape itself.
      new TensorFloat(out, outShape.clone(), contiguousStrides(outShape))
    }
  }

  /** Row-major (C-order) strides for a contiguous layout of `dims`, e.g. (2, 3, 4) -> (12, 4, 1). */
  private def contiguousStrides(dims: Array[Int]): Array[Int] = dims.scanRight(1)(_ * _).tail
}
/** Factory methods for [[TensorFloat]] plus an implicit lift of a bare `Float` to a 0-d tensor. */
object TensorFloat {
  /**
   * Builds a tensor from literal values.
   * Exactly one value yields a 0-dimensional scalar; two or more yield a 1-dimensional vector.
   */
  def apply(values: Float*): TensorFloat = {
    assert(values.nonEmpty)
    values match {
      case Seq(only) => new TensorFloat(Array(only), Array(), Array())
      case _ => TensorFloat(values.toArray)
    }
  }

  /** Wraps `ary` (without copying) as a contiguous 1-dimensional tensor. */
  def apply(ary: Array[Float]): TensorFloat = {
    assert(ary.length > 0)
    val vectorShape = Array(ary.length)
    val unitStride = Array(1)
    new TensorFloat(ary, vectorShape, unitStride)
  }

  /** Copies a rectangular array of rows into a contiguous row-major 2-dimensional tensor. */
  def apply(ary: Array[Array[Float]]): TensorFloat = {
    assert(ary.length > 0)
    assert(ary.forall(_.length == ary(0).length))
    val rows = ary.length
    val cols = ary(0).length
    new TensorFloat(ary.flatten, Array(rows, cols), Array(cols, 1))
  }

  implicit class ScalaFloat(val v: Float) extends TensorFloat(Array(v), Array(), Array())
}
/** Unit tests for [[TensorFloat]] addition, covering same-shape, scalar and rank-1/rank-2 broadcasting. */
class TensorFloatTest extends FunSuite {
  test("vector") {
    val t1 = TensorFloat(1f, 2f, 3f)
    val t2 = TensorFloat(4f, 5f, 6f)
    val t3 = t1 + t2
    assert(t3.shape.toSeq == Seq(3))
    assert(t3.toString == "TensorFloat(5.0, 7.0, 9.0)")
    // The bare Float on the left becomes a 0-d tensor via the implicit ScalaFloat class.
    val t4 = 10f + t3
    assert(t4.toString == "TensorFloat(15.0, 17.0, 19.0)")
  }

  test("matrix") {
    val t1 = TensorFloat(Array(Array(1f, 2f), Array(3f, 4f)))
    val t2 = TensorFloat(Array(Array(2f, 3f), Array(4f, 5f)))
    val t3 = t1 + t2
    assert(t3.shape.toSeq == Seq(2, 2))
    assert(t3.toString == "TensorFloat((3.0, 5.0), (7.0, 9.0))")
    val t4 = 10f + t3
    assert(t4.toString == "TensorFloat((13.0, 15.0), (17.0, 19.0))")
  }

  test("broadcast") {
    val t1 = TensorFloat(Array(Array(1f, 2f), Array(3f, 4f)))
    val t2 = TensorFloat(5f, 6f)
    val t3 = t1 + t2
    assert(t3.shape.toSeq == Seq(2, 2))
    assert(t3.toString == "TensorFloat((6.0, 8.0), (8.0, 10.0))")
  }

  test("scalar") {
    val t = TensorFloat(1f) + TensorFloat(2f)
    assert(t.ndim == 0)
    assert(t.toString == "TensorFloat(3.0)")
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment