@mindcrime
Created February 26, 2017 18:59
Complete example where values seem to not get scaled.
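For reference, this is the behavior expected from the scaler when it is applied to a DataSet directly (a minimal local sketch, not part of the original file; `ds` stands for a hypothetical DataSet built from one of the same images):

    DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
    scaler.preProcess(ds);  // ImagePreProcessingScaler divides pixel values by 255
    System.out.println(ds.getFeatures().maxNumber());  // should print a value <= 1.0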
package org.fogbeam.dl4j.spark;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.input.PortableDataStream;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.writable.Writable;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.spark.functions.RecordReaderFunction;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.datavec.DataVecDataSetFunction;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import com.google.common.base.Stopwatch;
public class ExpMain2
{
    public static void main(String[] args) throws Exception
    {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local");
        sparkConf.setAppName("SparkNeuralNetwork");

        Stopwatch sw = Stopwatch.createStarted();

        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        sc.hadoopConfiguration().set("mapreduce.input.fileinputformat.input.dir.recursive", "true");

        // read the raw PNG files as (path, stream) pairs
        JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles("/home/prhodes/development/experimental/ai_exp/NeuralNetworkSandbox/mnist_png/cutdown/0/**");

        // 28x28 single-channel images, labeled by parent directory name
        ImageRecordReader irr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
        List<String> labelsList = Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9");
        irr.setLabels(labelsList);

        // convert each image into a List<Writable> record
        RecordReaderFunction rrf = new RecordReaderFunction(irr);
        JavaRDD<List<Writable>> rdd = origData.map(rrf);
        System.out.println("DataSet RDD created");

        // the scaler is expected to map pixel values from [0, 255] into [0, 1];
        // it is passed to DataVecDataSetFunction as the preprocessor
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        JavaRDD<DataSet> trainingData = rdd.map(new DataVecDataSetFunction(1, 10, false, scaler, null));
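        // NOTE: if the features printed by the foreach below still show raw
        // values in [0, 255], the preprocessor passed above apparently never ran.
        // A possible workaround (an untested sketch, not part of the original
        // gist) is to apply the scaler to each DataSet explicitly:
        //
        //     trainingData = trainingData.map(ds -> { scaler.preProcess(ds); return ds; });
        //
        // this assumes scaler is serializable so Spark can ship it to the workers.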
        // debug: print the features of each DataSet so scaling can be verified;
        // scaled values should lie in [0, 1], unscaled ones range up to 255
        trainingData.foreach(new VoidFunction<DataSet>() {
            int count = 0;

            @Override
            public void call(DataSet arg0) throws Exception {
                System.out.println("count: " + count++ + "\n");
                System.out.println("features: " + arg0.getFeatures() + "\n");
            }
        });
        // simple feed-forward network: 784 -> 500 -> 100 -> 10
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(10)
            .activation(Activation.RELU)
            .weightInit(WeightInit.XAVIER)
            .learningRate(0.02)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .regularization(true).l2(1e-4)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
            .layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
            .pretrain(false).backprop(true)
            .setInputType(InputType.convolutional(28, 28, 1))
            .build();
        // create the TrainingMaster; each DataSet object holds one example
        int examplesPerDataSetObject = 1;
        TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(examplesPerDataSetObject)
            .build();

        // create the Spark network wrapper
        SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, trainingMaster);

        long elapsedPhase1 = sw.elapsed(TimeUnit.SECONDS);
        System.out.println("Loading data took " + elapsedPhase1 + " seconds. Starting to train model now.");

        // each call to fit() is one pass over the training RDD (20 epochs total)
        MultiLayerNetwork trainedNetwork = null;
        for (int i = 0; i < 20; i++)
        {
            trainedNetwork = sparkNet.fit(trainingData);
        }

        long elapsedPhase2 = sw.elapsed(TimeUnit.SECONDS);
        System.out.println("Training model took " + (elapsedPhase2 - elapsedPhase1) + " seconds.");
        System.out.println("Total elapsed time: " + elapsedPhase2);

        /* delete any existing model if there is one */
        File oldModelFile = new File("sparkTrainedNetwork.zip");
        if (oldModelFile.exists())
        {
            oldModelFile.delete();
        }

        ModelSerializer.writeModel(trainedNetwork, new File("sparkTrainedNetwork.zip"), false);
        System.out.println("done");
    }
}