import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;

import static org.apache.spark.sql.functions.rand;
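
/**
 * Text-classification example with Spark ML: reads a tab-separated lexicon
 * (word, polarity, category), builds a tokenize / stop-word-removal / TF-IDF /
 * multinomial logistic regression pipeline, tunes it with cross-validation,
 * and reports weighted precision on a held-out split.
 */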
public class App {

    public static void main(String[] args) {
        SparkSession spark = SparkSession
            .builder()
            .appName("Java Spark SQL Example")
            .getOrCreate();
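
        // Input schema: one lexicon entry per row, tab-separated.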
        StructType schema = new StructType()
            .add("word", "string")
            .add("polarity", "double")
            .add("category", "string");
        Dataset<Row> df = spark.read()
            .option("mode", "DROPMALFORMED")
            .option("delimiter", "\t")
            .option("header", "true")
            .schema(schema)
            .csv("src/main/resources/SEL-utf-8.txt");
        df.show(20);

        // Shuffle, then split 70/30 into training and test sets.
        Dataset<Row>[] split = df.orderBy(rand()).randomSplit(new double[] {0.7, 0.3});
        Dataset<Row> training = split[0];
        Dataset<Row> test = split[1];
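
        // Feature stages: index the category label, tokenize the word column,
        // drop stop words, hash tokens to term frequencies, and rescale with IDF.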
        StringIndexer indexer = new StringIndexer()
            .setInputCol("category") // the schema has no "label" column; the category is the label
            .setOutputCol("labelIndexed");
        Tokenizer tokenizer = new Tokenizer()
            .setInputCol("word")     // the schema has no "text" column; tokenize the word column
            .setOutputCol("tokens");
        StopWordsRemover stopWordsRemover = new StopWordsRemover()
            .setInputCol("tokens")
            .setOutputCol("clearedFromStopwords")
            .setStopWords(StopWordsRemover.loadDefaultStopWords("english"));
        HashingTF hashingTF = new HashingTF()
            .setInputCol("clearedFromStopwords")
            .setOutputCol("rawFeatures")
            .setNumFeatures(50000);
        IDF idf = new IDF()
            .setInputCol("rawFeatures")
            .setOutputCol("features");
        LogisticRegression lr = new LogisticRegression()
            .setMaxIter(10)
            .setRegParam(0.3)
            .setFamily("multinomial")
            .setLabelCol("labelIndexed");
        Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[] {indexer, tokenizer, stopWordsRemover, hashingTF, idf, lr});
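
        // Hyperparameter grid searched by the cross-validator; the pipeline's
        // own settings above act as defaults that each grid point overrides.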
        ParamMap[] paramGrid = new ParamGridBuilder()
            .addGrid(lr.maxIter(), new int[] {10, 20})
            .addGrid(lr.regParam(), new double[] {0.1, 1.0})
            .addGrid(lr.elasticNetParam(), new double[] {0.7})
            .addGrid(hashingTF.numFeatures(), new int[] {50000})
            .build();
        MulticlassClassificationEvaluator mce = new MulticlassClassificationEvaluator()
            .setLabelCol("labelIndexed")
            .setPredictionCol("prediction")
            .setMetricName("weightedPrecision");
        CrossValidator validator = new CrossValidator()
            .setNumFolds(2)
            .setEstimator(pipeline)
            .setEvaluator(mce)
            .setEstimatorParamMaps(paramGrid);
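
        // Fit with 2-fold cross-validation and keep the best pipeline found.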
        PipelineModel model = (PipelineModel) validator.fit(training).bestModel();
        try {
            model.save("src/main/resources/model");
        } catch (Exception e) {
            // Saving is best-effort here, but don't swallow the failure silently.
            e.printStackTrace();
        }
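
        // Score the held-out split.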
        Dataset<Row> predictions = model.transform(test);
        // Reuse the tuning evaluator: weighted precision, not accuracy.
        double weightedPrecision = mce.evaluate(predictions);
        predictions
            .select("category", "prediction", "labelIndexed", "word")
            .show(500);
        System.out.println("Weighted precision: " + weightedPrecision);
    }
}
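
// A minimal usage sketch, assuming the save above succeeded: the persisted
// pipeline can be reloaded in another job and applied to new rows that carry
// the same input columns (word, polarity, category):
//
//     PipelineModel reloaded = PipelineModel.load("src/main/resources/model");
//     Dataset<Row> scored = reloaded.transform(newData);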