Skip to content

Instantly share code, notes, and snippets.

@rjhall
Created July 25, 2014 13:43
Show Gist options
  • Save rjhall/b0f702f1ff97ee10a8e7 to your computer and use it in GitHub Desktop.
Save rjhall/b0f702f1ff97ee10a8e7 to your computer and use it in GitHub Desktop.
package com.etsy.scalding.jobs.conjecture
import com.etsy.scalding.conjecture.NNMF
import com.twitter.scalding.{Args, Job, Tsv, SequenceFile}
import org.apache.commons.math3.linear.RealVector
/**
 * Runs one iteration of NNMF over a matrix read from a TSV file and, via
 * [[next]], chains further iterations as separate job runs.
 *
 * Recognised arguments (all optional):
 *  - `iter`     index of the iteration this run performs (default 0)
 *  - `iters`    total number of iterations to run (default 20)
 *  - `base_dir` directory under which `H/<iter>`, `W/<iter>` and `err/<iter>` are written
 *               (default "nnmf_test")
 *  - `A`        path of the input TSV with columns row, col, val (default "critics.tsv")
 *  - `rank`     second argument handed to `NNMF.initGaussian` on iteration 0 (default 10);
 *               presumably the number of latent factors -- confirm against NNMF
 *
 * @param args command-line arguments of the job
 */
class NNMFTest(args : Args) extends Job(args) {
  val iter = args.getOrElse("iter", "0").toInt
  val iters = args.getOrElse("iters", "20").toInt
  val base_dir = args.getOrElse("base_dir", "nnmf_test")
  val A_path = args.getOrElse("A", "critics.tsv")
  // Previously hard-coded as 10; only consulted when iter == 0.
  val rank = args.getOrElse("rank", "10").toInt

  // Input matrix as (row, col, val) triples, with val parsed to a Double.
  val A = Tsv(A_path, ('row, 'col, 'val))
    .map('val -> 'val){ raw : String => raw.toDouble }

  /** Builds `<base_dir>/<name>/<iteration>`, the location of one iteration's output. */
  private def outputPath(name : String, iteration : Int) : String =
    base_dir + "/" + name + "/" + iteration

  // Factors fed into this iteration: a fresh initialisation on the first run,
  // otherwise whatever the previous iteration wrote.
  val HW = if(iter == 0) {
    NNMF.initGaussian(A, rank)
  } else {
    (SequenceFile(outputPath("H", iter - 1), ('row, 'vec)).read,
     SequenceFile(outputPath("W", iter - 1), ('col, 'vec)).read)
  }

  // Updated factors produced by this iteration, persisted for the next one.
  val HW_ = NNMF.updateGaussian(A, HW._1, HW._2)
  HW_._1.write(SequenceFile(outputPath("H", iter)))
  HW_._2.write(SequenceFile(outputPath("W", iter)))

  // The join keys of A are renamed so they do not clash with 'row / 'col of the predictions.
  private val observed = A.rename(('row, 'col) -> ('row_, 'col_))

  // Mean squared difference between 'val and the dot product of the factor vectors,
  // taken over every (row, col) pair of the crossed factors.
  // NOTE(review): this uses HW (the factors going INTO this iteration), not HW_, so
  // err/<iter> describes the state before this iteration's update -- confirm intended.
  // NOTE(review): pairs absent from A come out of the outer join without a 'val;
  // presumably that is read as 0.0 by the (Double, Double) conversion -- confirm.
  HW._1.crossWithSmaller(HW._2.rename('vec -> 'vec2))
    .map(('vec, 'vec2) -> 'pred){ vs : (RealVector, RealVector) => vs._1.dotProduct(vs._2) }
    .project('row, 'col, 'pred)
    .joinWithSmaller(('row, 'col) -> ('row_, 'col_), observed, new cascading.pipe.joiner.OuterJoin())
    .mapTo(('val, 'pred) -> 'err){ vp : (Double, Double) => val residual = vp._1 - vp._2; residual * residual }
    .groupAll{ _.average('err) }
    .write(Tsv(outputPath("err", iter)))

  /**
   * Schedules the following iteration, if any remain.
   *
   * @return a clone of this job with `iter` incremented by one while
   *         `iter < iters - 1`, otherwise `None`
   */
  override def next : Option[Job] = {
    if(iter < iters - 1) {
      val nextArgs = args + (("iter", List((iter + 1).toString)))
      Some(clone(nextArgs))
    } else {
      None
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment