Created
June 30, 2015 13:32
-
-
Save reisepass/e1b3b48bbbe8217fbfe6 to your computer and use it in GitHub Desktop.
Core portion of the Mean Field algorithm. Pretty much a one-to-one port of the MATLAB code.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
val Q = DenseMatrix.ones[Double](numRegions, numClasses) | |
var lastMaxE = 0.0 | |
var lastMinE = 0.0 | |
var numNoChange = 0 | |
for (iter <- 0 until maxIterations) { | |
var numUnchangedQs = 0 | |
val lastQ = Q; | |
val xiLab = (0 until numClasses).par | |
val allXiperLabel = xiLab.map(curLab => ((curLab, | |
for (xi <- 0 until graph.size) yield { | |
val neigh = graph.getC(xi).toArray | |
val allClasses = (0 until numClasses).toList | |
val newQest = neigh.toList.map { neighIdx => | |
allClasses.foldLeft(0.0) { (running, curClass) => | |
{ | |
running + Math.exp(lastQ(neighIdx, curClass) * (if (DISABLE_PAIRWISE) 0 else thetaPairwise(curClass, curLab))) * Math.exp((1 / temp) * thetaUnary(xi, curLab)) | |
} | |
} | |
}.sum | |
(1 / temp) * newQest | |
}))) | |
for (labAgain <- 0 until numClasses) { | |
val allXi = allXiperLabel(labAgain)._2.toArray | |
Q(::, labAgain) := DenseVector(allXi) | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment