Skip to content

Instantly share code, notes, and snippets.

@Jonty800
Created August 22, 2017 14:42
Show Gist options
  • Save Jonty800/682d3d01e78a58e8fc862b272c741789 to your computer and use it in GitHub Desktop.
Save Jonty800/682d3d01e78a58e8fc862b272c741789 to your computer and use it in GitHub Desktop.
package com.ukc.deeplearning;
/**
* Created by Jon Baker on 20/08/2017. <Part of Socialsense> Copyright University of Kent
*/
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import java.io.File;
import java.io.IOException;
/**
 * Trains and evaluates a DL4J classifier (two RBM layers followed by a softmax output layer)
 * on CSV data read from {@code csv/train.csv} and {@code csv/eval.csv}.
 *
 * <p>NOTE(review): each CSV row is presumably laid out as the class label in column
 * {@code labelIndex} (0) followed by {@link #NUM_INPUTS} numeric feature columns. This is inferred
 * from the reader and layer configuration; confirm it against the data files.
 */
public class DeepLearning {

    /** Number of feature columns fed to the first layer. */
    private static final int NUM_INPUTS = 19;

    /** Number of output classes predicted by the softmax layer. */
    private static final int NUM_CLASSES = 28;

    /** Width of both hidden RBM layers. */
    private static final int HIDDEN_SIZE = 1024;

    /**
     * Loads the training and evaluation CSVs, standardizes them using training-set statistics,
     * trains the network and prints evaluation statistics for the evaluation set.
     *
     * @param args unused
     * @throws Exception if a CSV file cannot be read or training fails
     */
    public static void main(String[] args) throws Exception {
        int labelIndex = 0;
        int numClasses = NUM_CLASSES;

        // Each file is read as a single batch, so the batch size must be at least the number of
        // rows in the file; readCSVDataset warns if any rows are left unread.
        int batchSizeTraining = 2828;
        DataSet trainingData =
                readCSVDataset("csv/train.csv", batchSizeTraining, labelIndex, numClasses);

        // This is the data we want to classify.
        int batchSizeTest = 11605;
        DataSet testData = readCSVDataset("csv/eval.csv", batchSizeTest, labelIndex, numClasses);

        // Standardize to zero mean / unit variance. Statistics are collected from the training
        // data only (fit does not modify its input) and then applied to both sets.
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainingData);
        normalizer.transform(trainingData);
        normalizer.transform(testData);
        // Deliberately no DataSet.scale() here. It divides every row by that row's maximum,
        // which would undo the per-column standardization above. It would also divide by zero
        // or flip signs for rows whose maximum is <= 0.

        MultiLayerNetwork model = buildModel(NUM_INPUTS, numClasses);
        model.fit(trainingData);

        Evaluation eval = new Evaluation(numClasses);
        INDArray output = model.output(testData.getFeatureMatrix());
        eval.eval(testData.getLabels(), output);
        System.out.println(eval.stats());
    }

    /**
     * Builds and initializes the network with the default input width and class count.
     *
     * @return an initialized network with a {@link ScoreIterationListener} attached
     */
    public static MultiLayerNetwork buildModel() {
        return buildModel(NUM_INPUTS, NUM_CLASSES);
    }

    /**
     * Builds and initializes an RBM(numInputs -> 1024) -> RBM(1024 -> 1024) ->
     * softmax(1024 -> numClasses) network.
     *
     * @param numInputs number of feature columns per example; must be positive
     * @param numClasses number of output classes; must be positive
     * @return an initialized network with a {@link ScoreIterationListener} attached
     * @throws IllegalArgumentException if {@code numInputs} or {@code numClasses} is not positive
     */
    public static MultiLayerNetwork buildModel(int numInputs, int numClasses) {
        if (numInputs <= 0) {
            throw new IllegalArgumentException("numInputs must be positive: " + numInputs);
        }
        if (numClasses <= 0) {
            throw new IllegalArgumentException("numClasses must be positive: " + numClasses);
        }
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                // Full-batch training: each file is loaded as one DataSet.
                .miniBatch(false)
                .weightInit(WeightInit.RELU)
                .iterations(10)
                .learningRate(0.2)
                .updater(Updater.ADAGRAD).gradientNormalization(GradientNormalization.ClipL2PerLayer)
                .regularization(true).l2(1e-1).l1(1e-3)
                .optimizationAlgo(OptimizationAlgorithm.LBFGS)
                .list()
                // NOTE(review): no activation is set on this layer, so the configuration-level
                // default applies, whereas layer 1 uses RELU. Confirm this is intended.
                .layer(0, new RBM.Builder()
                        .nIn(numInputs)
                        .nOut(HIDDEN_SIZE)
                        .weightInit(WeightInit.RELU)
                        .visibleUnit(RBM.VisibleUnit.GAUSSIAN)
                        .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
                        .build())
                .layer(1, new RBM.Builder()
                        .nIn(HIDDEN_SIZE)
                        .nOut(HIDDEN_SIZE)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.RELU)
                        .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
                        .build())
                .layer(2, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.RELU)
                        .activation(Activation.SOFTMAX).nIn(HIDDEN_SIZE).nOut(numClasses).build())
                // NOTE(review): with pretrain(false) the RBM layers are presumably not
                // pretrained layer-wise and are trained only through backprop. Confirm this
                // is intended.
                .backprop(true).pretrain(false)
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(10));
        return net;
    }

    /**
     * Reads up to {@code batchSize} records from a CSV file into a single {@link DataSet}.
     * Used for both the training and the evaluation data.
     *
     * @param csvFilePath filesystem path of the CSV file (resolved with {@link File}, not the
     *     classpath)
     * @param batchSize maximum number of records to load; should be at least the number of rows
     *     in the file, otherwise the remaining rows are dropped and a warning is printed
     * @param labelIndex column index holding the class label
     * @param numClasses number of distinct class labels
     * @return the first batch of records from the file
     * @throws IOException if the file does not exist, contains no records, or cannot be read
     * @throws InterruptedException if the record reader is interrupted during initialization
     */
    private static DataSet readCSVDataset(
            String csvFilePath, int batchSize, int labelIndex, int numClasses)
            throws IOException, InterruptedException {
        File csvFile = new File(csvFilePath);
        if (!csvFile.isFile()) {
            throw new IOException("CSV file not found: " + csvFile.getAbsolutePath());
        }
        // The DataSet is fully materialized by iterator.next(), so the reader (and its file
        // handle) can be closed before returning.
        try (RecordReader rr = new CSVRecordReader()) {
            rr.initialize(new FileSplit(csvFile));
            DataSetIterator iterator =
                    new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses);
            if (!iterator.hasNext()) {
                throw new IOException("CSV file contains no records: " + csvFile.getAbsolutePath());
            }
            DataSet data = iterator.next();
            if (iterator.hasNext()) {
                // Only one batch is returned; make the truncation visible instead of silent.
                System.err.println("Warning: " + csvFilePath + " has more than " + batchSize
                        + " records; only the first " + batchSize + " were loaded.");
            }
            return data;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment