Skip to content

Instantly share code, notes, and snippets.

@alpicola
Created July 28, 2014 04:41
Show Gist options
  • Save alpicola/da5db098d2d5bac274ad to your computer and use it in GitHub Desktop.
Save alpicola/da5db098d2d5bac274ad to your computer and use it in GitHub Desktop.
Biterm Topic Model
// X. Yan, J. Guo, Y. Lan, and X. Cheng, "A Biterm Topic Model for Short Texts,"
// in Proceedings of the 22nd International Conference on World Wide Web (WWW '13),
// ACM, 2013, pp. 1445–1456.
import scala.collection._
import scala.io.Source
import scala.util.Random
import java.io._
/** Biterm Topic Model trained with collapsed Gibbs sampling.
  *
  * Usage: call [[load]] on a corpus file, then [[estimate]], then [[report]].
  *
  * @param alpha symmetric Dirichlet prior on the topic proportions theta
  * @param beta  symmetric Dirichlet prior on each topic's word distribution phi
  * @param k     number of topics; must be positive
  * @param iterN number of Gibbs sweeps over all biterms performed by [[estimate]]
  */
class BTM(val alpha:Double, val beta:Double, val k:Int, val iterN:Int) {
  require(k > 0, s"k must be positive, got $k")

  // Vocabulary: words(w) is the string for word id w.
  private var words:Array[String] = null
  // One (document id, word ids) pair per input line, in file order.
  private var documents:Array[(String, Array[Int])] = null
  // Every biterm of the corpus, stored as (smaller word id, larger word id).
  private var biterms:Array[(Int, Int)] = null
  // Vocabulary size.
  private var m:Int = 0
  // b_z(i) is the topic currently assigned to biterms(i).
  private var b_z:Array[Int] = null
  // n_z(z): number of biterms assigned to topic z.
  private var n_z:Array[Long] = null
  // n_w_z(w*k+z): occurrences of word w in biterms assigned to topic z.
  private var n_w_z:Array[Long] = null
  // theta(z) = P(z); filled in by calcTheta.
  private var theta:Array[Double] = null
  // phi(w*k+z) = P(w|z); filled in by calcPhi (same layout as n_w_z).
  private var phi:Array[Double] = null
  // Scratch buffer for the cumulative topic weights used by sampleTopic.
  private var table:Array[Double] = null

  /** Reads a tab-separated UTF-8 corpus and allocates all model buffers.
    *
    * Each line is `id<TAB>word<TAB>word...`. Words are mapped to dense integer
    * ids in order of first appearance, and each document's biterms are
    * appended to the corpus-wide biterm list.
    *
    * @param file path of the corpus file
    */
  def load(file:String): Unit = {
    val dict = mutable.HashMap[String, Int]()
    val count = Iterator.from(0)
    val buf1 = mutable.ArrayBuffer[(String, Array[Int])]()
    val buf2 = mutable.ArrayBuffer[(Int, Int)]()
    // Decode explicitly: the corpus contains non-ASCII words and the JVM's
    // default charset is platform-dependent.
    val s = Source.fromFile(file, "UTF-8")
    try {
      s.getLines().foreach { line =>
        val row = line.stripLineEnd.split("\t")
        // getOrElseUpdate evaluates count.next() only for unseen words.
        val d = row.tail.map(word => dict.getOrElseUpdate(word, count.next()))
        buf1 += ((row.head, d))
        buf2 ++= getBiterms(d)
      }
    } finally {
      s.close()
    }
    m = count.next()
    words = new Array(m)
    dict.foreach { case (word, i) => words(i) = word }
    documents = buf1.toArray
    biterms = buf2.toArray
    b_z = new Array(biterms.length)
    n_z = new Array(k)
    n_w_z = new Array(k * m)
    theta = new Array(k)
    phi = new Array(k * m)
    table = new Array(k)
    println(s"|B|: ${biterms.length}, K: $k, M: $m")
  }

  /** Runs `iterN` Gibbs sweeps from a random start, then computes theta and phi. */
  def estimate: Unit = {
    require(biterms != null, "load must be called before estimate")
    java.util.Arrays.fill(n_z, 0L)
    java.util.Arrays.fill(n_w_z, 0L)
    // Random initial topic for every biterm.
    biterms.iterator.zipWithIndex.foreach { case (b, i) =>
      setTopic(b, i, Random.nextInt(k))
    }
    Iterator.range(0, iterN).foreach { n =>
      println(s"iteration ${n+1}")
      biterms.iterator.zipWithIndex.foreach { case (b, i) =>
        // Remove b's own counts before drawing its new topic.
        unsetTopic(b, b_z(i))
        setTopic(b, i, sampleTopic(b))
      }
    }
    calcTheta
    calcPhi
    println("done!")
  }

  /** Writes three UTF-8 tab-separated result files into the working directory.
    *
    *  - `topics.k<K>`: per topic z, theta(z) followed by the 20 words with highest P(w|z).
    *  - `words.k<K>`: per word, the word followed by P(z|w) for all K topics.
    *  - `documents.k<K>`: per document, its id followed by P(z|d) for all K topics
    *    (all zeros for a document without biterms).
    */
  def report: Unit = {
    require(documents != null, "load must be called before report")

    writeFile(s"topics.k$k") { out =>
      (0 until k).foreach { z =>
        val ws = (0 until m).sortBy(w => -phi(w*k+z)).take(20)
        out.println(s"${theta(z)}\t" + ws.map(w => words(w)).mkString("\t"))
      }
    }

    writeFile(s"words.k$k") { out =>
      (0 until m).foreach { w =>
        // Bayes' rule over ALL k topics: P(z|w) ∝ P(w|z) P(z).
        val weight = Array.tabulate(k)(z => phi(w*k+z) * theta(z))
        val h = 1.0 / weight.sum
        out.println(s"${words(w)}\t" + weight.iterator.map(_ * h).mkString("\t"))
      }
    }

    writeFile(s"documents.k$k") { out =>
      documents.foreach { case (id, d) =>
        val bs = getBiterms(d).toArray
        // P(z|d) = Σ_b P(z|b) P(b|d), with P(b|d) uniform over the document's
        // biterm occurrences and P(z|b) ∝ theta(z) P(w1|z) P(w2|z).
        val p_z_d = new Array[Double](k)
        bs.foreach { case (w1, w2) =>
          val p = Array.tabulate(k)(z => theta(z) * phi(w1*k+z) * phi(w2*k+z))
          val scale = 1.0 / (p.sum * bs.length)
          (0 until k).foreach(z => p_z_d(z) += p(z) * scale)
        }
        out.println(s"${id}\t" + p_z_d.mkString("\t"))
      }
    }
  }

  /** Opens `name` as a UTF-8 PrintWriter, runs `body`, and always closes it.
    *
    * @throws IOException if any write failed (PrintWriter itself swallows errors)
    */
  private def writeFile(name:String)(body: PrintWriter => Unit): Unit = {
    val out = new PrintWriter(new File(name), "UTF-8")
    try {
      body(out)
      if (out.checkError()) throw new IOException(s"error writing $name")
    } finally {
      out.close()
    }
  }

  /** Returns the biterms (unordered word pairs) of one document as (min id, max id).
    *
    * Seq.combinations treats repeated elements as identical, so each distinct
    * pair is emitted once and a repeated word yields a single (w, w) biterm.
    * NOTE(review): confirm this de-duplication is intended.
    */
  private def getBiterms(d:Array[Int]):Iterator[(Int, Int)] = {
    d.toSeq.combinations(2).map { case Seq(w1, w2) =>
      if (w1 < w2) (w1, w2) else (w2, w1)
    }
  }

  /** Assigns topic z to biterm i (= b) and adds its counts. */
  private def setTopic(b:(Int, Int), i:Int, z:Int): Unit = {
    val (w1, w2) = b
    b_z(i) = z
    n_z(z) += 1
    n_w_z(w1*k+z) += 1
    n_w_z(w2*k+z) += 1
  }

  /** Removes biterm b's counts from topic z (b_z is left untouched). */
  private def unsetTopic(b:(Int, Int), z:Int): Unit = {
    val (w1, w2) = b
    n_z(z) -= 1
    n_w_z(w1*k+z) -= 1
    n_w_z(w2*k+z) -= 1
  }

  /** Draws a topic for b from the collapsed conditional
    * P(z | rest) ∝ (n_z + α)(n_{w1|z} + β)(n_{w2|z} + β) / (2 n_z + Mβ)².
    *
    * Each factor h carries an extra constant m, which cancels on normalisation.
    * The caller must already have removed b's own counts via unsetTopic.
    */
  private def sampleTopic(b:(Int, Int)):Int = {
    val (w1, w2) = b
    Iterator.range(0, k).map { z =>
      val h = m / (n_z(z) * 2 + m * beta)
      val p_z_w1 = (n_w_z(w1*k+z) + beta) * h
      val p_z_w2 = (n_w_z(w2*k+z) + beta) * h
      (n_z(z) + alpha) * p_z_w1 * p_z_w2
    }.scanLeft(0.0)(_ + _).drop(1).copyToArray(table)
    // Inverse-CDF sampling over the cumulative weights.
    val r = Random.nextDouble() * table.last
    table.indexWhere(_ >= r)
  }

  /** theta(z) = (n_z + α) / (|B| + Kα). */
  private def calcTheta: Unit = {
    Iterator.range(0, k).map { z =>
      (n_z(z) + alpha) / (biterms.length + k * alpha)
    }.copyToArray(theta)
  }

  /** phi(w*k+z) = (n_{w|z} + β) / (2 n_z + Mβ); every biterm adds two word tokens to its topic. */
  private def calcPhi: Unit = {
    Iterator.range(0, m).flatMap { w =>
      Iterator.range(0, k).map { z =>
        (n_w_z(w*k+z) + beta) / (n_z(z) * 2 + m * beta)
      }
    }.copyToArray(phi)
  }
}
object BTM {
  /** Command-line entry point.
    *
    * Usage: `BTM <corpus.tsv> [K] [iterations]`
    *
    * K defaults to 20 and iterations to 200. alpha is 1/K, which is exactly
    * the previous hard-coded 1.0 / 20 at the default K. beta is 0.01.
    */
  def main(args:Array[String]): Unit = {
    if (args.isEmpty) {
      System.err.println("usage: BTM <corpus.tsv> [K] [iterations]")
      sys.exit(1)
    }
    // Optional positional integer argument with a default.
    def intArg(i:Int, default:Int): Int =
      if (args.length > i) args(i).toInt else default
    val k = intArg(1, 20)
    val iterN = intArg(2, 200)
    val btm = new BTM(1.0 / k, 0.01, k, iterN)
    btm.load(args(0))
    btm.estimate
    btm.report
  }
}
# coding: utf-8
# Exports tweets stored in MongoDB as a tab-separated corpus, one line per
# tweet: "id_str<TAB>token<TAB>token...".
#
# Tokens are nouns (surface form) and verbs/adjectives (feature[6] form) found
# by MeCab with the IPA dictionary, minus the stop words listed one per line in
# stopwords.txt. They are followed by the tweet's @mentions, #hashtags and link
# domains. Tweets with fewer than two tokens are skipped.
# MONGODB_DB and MONGODB_COLLECTION are read from .env.
require 'mongo'
require 'MeCab'
require 'dotenv'
require 'set'
require 'cgi'
Dotenv.load
db = Mongo::Connection.new.db(ENV['MONGODB_DB'])
collection = db.collection(ENV['MONGODB_COLLECTION'])
mecab = MeCab::Tagger.new('-d /usr/share/mecab/dic/ipadic')

# A Set makes the per-token stop-word lookup O(1) instead of a linear scan.
stopwords = Set.new
File.foreach('stopwords.txt', encoding: 'UTF-8') {|line| stopwords << line.chomp }

File.open('tweets.tsv', 'w:UTF-8') {|f|
  collection.find.each do |status|
    # A retweet is exported as the tweet it retweets.
    if status['retweeted_status']
      status = status['retweeted_status']
    end
    # Work on a copy so stripping entities does not mutate the fetched document.
    text = status['text'].dup
    mentions = []
    hashtags = []
    domains = []
    # Entities become dedicated tokens and are cut out of the text so MeCab
    # does not tokenize them.
    status['entities']['user_mentions'].each do |item|
      mentions << '@' + item['screen_name']
      text.sub!(mentions.last, '')
    end
    status['entities']['hashtags'].each do |item|
      hashtags << '#' + item['text']
      text.sub!(hashtags.last, '')
    end
    status['entities']['urls'].each do |item|
      domains << item['display_url'].split('/')[0]
      text.sub!(item['url'], '')
    end
    (status['entities']['media'] || []).each do |item|
      text.sub!(item['url'], '')
    end
    # Twitter returns tweet text with &, < and > HTML-escaped. Without unescaping,
    # MeCab sees "&amp;" etc. and "amp"/"lt"/"gt" leak into the vocabulary.
    text = CGI.unescapeHTML(text)

    words = []
    node = mecab.parseToNode(text)
    while node
      word = nil
      feature = node.feature.split(',')
      case feature[0]
      when '名詞'
        word = node.surface
      when '動詞', '形容詞', '形容動詞'
        # feature[6] is presumably the base form in the IPA dictionary layout;
        # '*' means it is unavailable.
        if feature[6] != '*'
          word = feature[6]
        end
      end
      if word && !stopwords.include?(word)
        words << word
      end
      node = node.next
    end
    words = words + mentions + hashtags + domains
    if words.length > 1
      words.unshift(status['id_str'])
      f.puts words.join("\t")
    end
  end
}
# coding: utf-8
# Saves tweets from the Twitter sample stream whose author's account language
# is Japanese into MongoDB, then exits after a fixed number of tweets.
#
# Read from .env: CONSUMER_KEY, CONSUMER_SECRET, ACCESS_TOKEN_KEY,
# ACCESS_SECRET, MONGODB_DB, MONGODB_COLLECTION. Optional: TWEET_LIMIT
# (number of tweets to save, default 100000).
require 'tweetstream'
require 'mongo'
require 'dotenv'
Dotenv.load
db = Mongo::Connection.new.db(ENV['MONGODB_DB'])
collection = db.collection(ENV['MONGODB_COLLECTION'])
TweetStream.configure do |config|
  config.consumer_key = ENV['CONSUMER_KEY']
  config.consumer_secret = ENV['CONSUMER_SECRET']
  config.oauth_token = ENV['ACCESS_TOKEN_KEY']
  config.oauth_token_secret = ENV['ACCESS_SECRET']
  config.auth_method = :oauth
end
count = 0
limit = Integer(ENV['TWEET_LIMIT'] || 100000)
TweetStream::Client.new.sample do |status|
  # Filters on the author's account language, not the tweet's own language.
  next unless status.user.lang == 'ja'
  collection.insert(status.to_h)
  count += 1
  puts "saved #{count} tweets" if count % 100 == 0
  # Checked on every tweet, so a limit that is not a multiple of 100 is honoured exactly.
  if count >= limit
    puts "done!"
    exit
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment