@dmmiller612
Last active June 21, 2016 14:30
Java Spark gist for linear regression after a group-by: converting a JavaPairRDD<Something, Iterable<Something>> into a Map<Something, LinearRegressionModel>.
package cdapp.services;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import weatherapp.models.CD;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * The main goal of this gist is to show how to convert a JavaPairRDD<Something, Iterable<Something>> into a Map<String, JavaRDD<Something>>,
 * then perform machine learning on each group. It also demonstrates ML pipelines, Spark SQL, one-hot encoding, etc.
 * I am still adding to this gist.
 */
public class SparkCDGist {

    final SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("cds");
    final JavaSparkContext sparkContext = new JavaSparkContext(conf);
    final SQLContext sqlContext = new SQLContext(sparkContext);

    /**
     * Creates a DataFrame from the incoming JSON rows, whose shape is
     * {"id": String, "condition": String, "location": String, "price": Integer}
     * (note: price is an integer, matching DataTypes.IntegerType in the schema below).
     * @param rows the parsed rows
     * @return a DataFrame with the CD schema
     */
    private DataFrame createDataFrame(JavaRDD<Row> rows){
        StructType schema = new StructType(new StructField[]{
                new StructField("id", DataTypes.StringType, false, Metadata.empty()),
                new StructField("condition", DataTypes.StringType, false, Metadata.empty()),
                new StructField("location", DataTypes.StringType, false, Metadata.empty()),
                new StructField("price", DataTypes.IntegerType, false, Metadata.empty())
        });
        return sqlContext.createDataFrame(rows, schema);
    }
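
    /**
     * Illustrative only (not part of the original gist): a row matching the
     * schema above, as would be parsed from the JSON
     * {"id": "1", "condition": "used", "location": "TX", "price": 10}.
     * The sample values are hypothetical; note that price must be an Integer
     * to satisfy DataTypes.IntegerType.
     */
    private Row exampleRow(){
        return RowFactory.create("1", "used", "TX", 10);
    }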

    /**
     * Indexes the string columns so that one-hot encoding can take place.
     */
    private List<StringIndexer> getStringIndexers(List<String> keys){
        return keys.stream()
                .map(item -> new StringIndexer()
                        .setInputCol(item)
                        .setOutputCol(item + "_index"))
                .collect(Collectors.toList());
    }

    /**
     * Sets up the one-hot encoders for the pipeline.
     */
    private List<OneHotEncoder> getOneHotEncoders(List<String> keys, String suffix){
        return keys.stream().map(name -> new OneHotEncoder()
                .setInputCol(name + suffix)
                .setOutputCol(name + suffix + "_vec")) // output column is name + suffix + "_vec", e.g. "condition_index_vec"
                .collect(Collectors.toList());
    }
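
    // For example, with suffix "_index" the full chain for "condition" is:
    //   "condition" --StringIndexer--> "condition_index" (numeric category)
    //   "condition_index" --OneHotEncoder--> "condition_index_vec" (sparse 0/1 vector)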

    /**
     * Creates the pipeline from the indexers, encoders, and assembler.
     */
    private Pipeline createPipeline(List<StringIndexer> stringIndexers, List<OneHotEncoder> oneHotEncoders, VectorAssembler vectorAssembler){
        int totalSize = stringIndexers.size() + oneHotEncoders.size();
        PipelineStage[] pipelineStages = new PipelineStage[totalSize + 1];
        for (int i = 0; i < totalSize; i++){
            if (i < stringIndexers.size()){
                pipelineStages[i] = stringIndexers.get(i);
            } else {
                pipelineStages[i] = oneHotEncoders.get(i - stringIndexers.size());
            }
        }
        pipelineStages[totalSize] = vectorAssembler;
        return new Pipeline().setStages(pipelineStages);
    }
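
    // Stage order matters: the StringIndexers run first, then the OneHotEncoders
    // that consume their "_index" outputs, and the VectorAssembler runs last to
    // combine the encoded vectors into a single features column.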

    /**
     * Builds the indexers and encoders for the CD columns and assembles the full pipeline.
     */
    private Pipeline handleCds(){
        List<String> names = Arrays.asList("condition", "location");
        List<StringIndexer> stringIndexers = getStringIndexers(names);
        List<OneHotEncoder> oneHotEncoders = getOneHotEncoders(names, "_index");
        return createPipeline(stringIndexers, oneHotEncoders, new VectorAssembler()
                .setInputCols(new String[]{"condition_index_vec", "location_index_vec"})
                .setOutputCol("features"));
    }
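
    // To include another categorical column, add it to "names" above and add the
    // corresponding "<name>_index_vec" entry to the VectorAssembler's input columns.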

    /**
     * Trains a linear regression model on the pipeline's "features" column, with "price" as the label.
     */
    public LinearRegressionModel performMachineLearning(DataFrame df){
        Pipeline pipeline = handleCds();
        PipelineModel pm = pipeline.fit(df);
        JavaRDD<Row> vecFeatures = pm.transform(df).select("features", "price").javaRDD();
        JavaRDD<LabeledPoint> labeledPoints = vecFeatures.map(item ->
                new LabeledPoint((double) item.getInt(item.size() - 1), (Vector) item.get(0)));
        labeledPoints.cache();
        int numIterations = 300;
        double stepSize = 0.000001;
        return LinearRegressionWithSGD.train(JavaRDD.toRDD(labeledPoints), numIterations, stepSize);
    }
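
    /**
     * A minimal evaluation sketch (not part of the original gist): computes the
     * mean squared error of a fitted model over a set of labeled points, e.g.
     * the training points built in performMachineLearning above.
     */
    private double meanSquaredError(LinearRegressionModel model, JavaRDD<LabeledPoint> points){
        return points
                .mapToDouble(p -> {
                    double error = p.label() - model.predict(p.features());
                    return error * error;
                })
                .mean();
    }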

    /**
     * Gets an RDD by key from a JavaPairRDD.
     */
    private JavaRDD<Row> getRddByKey(JavaPairRDD<String, Iterable<Row>> pairRDD, String key) {
        return pairRDD.filter(v -> v._1().equals(key)).values().flatMap(tuples -> tuples);
    }
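
    // Note: this scans the full pair RDD once per key, so regressionModels()
    // below makes one pass over the data for every distinct key. That is fine
    // for a handful of keys; for many keys, consider collecting the grouped
    // values once instead of re-filtering.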

    /**
     * Main method that returns a map with the group key as the key and the trained regression model as the value.
     */
    public Map<String, LinearRegressionModel> regressionModels(){
        JavaPairRDD<String, Iterable<Row>> cds = getCds();
        Map<String, LinearRegressionModel> models = new HashMap<>();
        try {
            List<String> keys = cds.keys().distinct().collect();
            for (String key : keys){
                JavaRDD<Row> rddByKey = getRddByKey(cds, key);
                DataFrame df = createDataFrame(rddByKey);
                models.put(key, performMachineLearning(df));
            }
        } catch (Exception e) {
            e.printStackTrace(); // log the failure; any models trained so far are still returned
        }
        return models;
    }

    /**
     * Gets all of the CD data from HDFS.
     */
    public JavaPairRDD<String, Iterable<Row>> getCds(){
        return sparkContext.textFile("hdfs://localhost:9000/users/cd/*/")
                .map(item -> {
                    Row row = null;
                    try {
                        ObjectMapper objectMapper = new ObjectMapper();
                        CD cd = objectMapper.readValue(item, CD.class);
                        row = RowFactory.create(cd.getId(), cd.getCondition(), cd.getLocation(), cd.getPrice());
                    } catch (Exception e){
                        // leave row as null; unparseable lines are filtered out below
                    }
                    return row;
                })
                .filter(row -> row != null) // drop failed parses to avoid a NullPointerException in groupBy
                .groupBy(item -> item.getString(0));
    }
}
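
/**
 * A minimal usage sketch (assumed, not part of the original gist): builds the
 * per-key models and prints each model's learned weights and intercept.
 */
class SparkCDGistExample {
    public static void main(String[] args) {
        SparkCDGist gist = new SparkCDGist();
        Map<String, LinearRegressionModel> models = gist.regressionModels();
        models.forEach((key, model) ->
                System.out.println(key + " -> weights: " + model.weights() + ", intercept: " + model.intercept()));
        gist.sparkContext.stop();
    }
}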