Skip to content

Instantly share code, notes, and snippets.

@chy168
Created March 5, 2016 03:24
Show Gist options
  • Save chy168/c599a2b2f85697a942e6 to your computer and use it in GitHub Desktop.
Save chy168/c599a2b2f85697a942e6 to your computer and use it in GitHub Desktop.
Spark MLlib K-means example
// k-means
import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
/**
 * One input record, reduced to the three values this script keeps.
 *
 * @param userId  user identifier (taken from CSV column 0 where this script builds it)
 * @param videoId video identifier (second `=`-separated segment of CSV column 1)
 * @param isKid   CSV column 5, kept verbatim as a string (presumably "Yes"/"No" -- confirm)
 */
case class Data(
  userId: String,
  videoId: String,
  isKid: String
)
// Parse the CSV into Data records.
//  * column 0 -> userId
//  * column 1 -> videoId, taken as the second "="-separated segment
//  * column 5 -> isKid
// Lines that cannot be parsed (header row, blank line, fewer than six columns, or no
// "=" payload in column 1) are skipped instead of failing the whole job with an
// ArrayIndexOutOfBoundsException.
val raw_data = sc.textFile("/Users/zzchen/Downloads/locked_video_up5.csv").flatMap { line =>
  // limit = -1 keeps trailing empty fields, so an empty last column is still addressable.
  val fields = line.split(",", -1)
  if (fields.length > 5) {
    fields(1).split("=") match {
      case Array(_, videoId, _*) => Some(Data(fields(0), videoId, fields(5)))
      case _                     => None
    }
  } else {
    None
  }
}
/*
val df = raw_data.toDF().groupBy("_1").pivot("_2").count()
df.show()
*/
// Give every distinct user id string a unique Long id and publish the (userId, id) pairs
// as temp table "user". toDF() on a tuple RDD uses the default column names _1 and _2.
val user_ids = raw_data.map(_.userId).distinct().zipWithUniqueId()
user_ids.toDF().registerTempTable("user")

// Publish the parsed records themselves as temp table "raw" (userId, videoId, isKid).
val raw_df = raw_data.toDF()
raw_df.registerTempTable("raw")
//+--------------------+---+--------------------+-----------+-----+
//| _1| _2| userId| videoId|isKid|
//+--------------------+---+--------------------+-----------+-----+
//|ooxxxxxxxxxxxxxx@...| 6|ooxxxxxxxxxxxxxx@...|ToKMrLuH_iM| Yes|
//+--------------------+---+--------------------+-----------+-----+
// Attach the generated numeric user id to every raw record by joining on the user id string.
val join_query = "SELECT * FROM user u JOIN raw r ON u.`_1` = r.userId"
val joined_data = sqlContext.sql(join_query)

// One row per numeric user id (_2), one column per distinct videoId, each cell a row count.
val df3 = joined_data
  .groupBy("_2")
  .pivot("videoId")
  .count()
// Turn every pivoted row into a dense feature vector of per-video counts.
//  * `.rdd` is explicit so the result is an RDD[Vector] on every Spark version
//    (on Spark 2.x, DataFrame.map yields a Dataset and needs an encoder).
//  * Column 0 is the grouping key `_2` (the numeric user id), not a feature, so it is skipped.
//  * A user/video pair with no rows can surface as null after pivot + count; it is
//    treated as a count of 0 instead of raising a MatchError.
//  * The RDD is cached because K-means training, computeCost and the prediction pass each
//    traverse it, and recomputing the join + pivot every time would be wasteful.
val input_data = df3.rdd.map { row =>
  val counts = Array.tabulate(row.length - 1) { i =>
    val col = i + 1
    if (row.isNullAt(col)) 0.0 else row.getLong(col).toDouble
  }
  Vectors.dense(counts)
}.cache()
// Dump the pivoted count table through the spark-csv data source, including a header row.
// NOTE(review): the target is presumably created as a directory of part files and the
// save fails if it already exists -- confirm against the Spark version in use.
val csv_writer = df3.write
  .format("com.databricks.spark.csv")
  .option("header", "true")
csv_writer.save("GGGG.csv")
// K-means hyper-parameters: number of clusters and maximum number of iterations.
val numClusters = 6
val numIterations = 20

// Fit the model with the builder API; all other settings stay at their defaults.
val clusters = new KMeans()
  .setK(numClusters)
  .setMaxIterations(numIterations)
  .run(input_data)

// Cost of the fitted model on the training vectors, as reported by computeCost.
val WSSSE = clusters.computeCost(input_data)
// Report the cluster assigned to every training vector.
// The prediction is computed on the executors, but the (vector, cluster) pairs are
// collected first and printed on the driver, so the output appears in this console even
// when the job runs on a cluster (a println inside map writes to executor stdout).
input_data
  .map(features => (features, clusters.predict(features)))
  .collect()
  .foreach { case (features, cluster) =>
    println(s"Predict: $features : $cluster")
  }
// Two hand-written probe vectors, classified with the fitted model.
// NOTE(review): their length has to equal the number of pivoted videoId columns the model
// was trained on -- confirm whenever the input data changes.
val probe_a = Vectors.dense(Array(0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0))
clusters.predict(probe_a)
val probe_b = Vectors.dense(Array(0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0))
clusters.predict(probe_b)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment