Skip to content

Instantly share code, notes, and snippets.

@ABeltramo
Created October 16, 2019 14:41
Show Gist options
  • Save ABeltramo/51eb2bdea3a34028b157078d6f43463d to your computer and use it in GitHub Desktop.
Save ABeltramo/51eb2bdea3a34028b157078d6f43463d to your computer and use it in GitHub Desktop.
KerasJavaPredictor
<!--
  Maven dependencies for the Keras predictor classes.
  A POM may contain only one <dependencies> element per <project>, so all three
  artifacts are declared in a single block. Keep the three versions identical.
-->
<dependencies>
    <!-- Deeplearning4j core library -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-beta3</version>
    </dependency>
    <!-- Keras model import (provides org.deeplearning4j.nn.modelimport.keras.KerasModelImport) -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-modelimport</artifactId>
        <version>1.0.0-beta3</version>
    </dependency>
    <!-- ND4J native backend for INDArray operations -->
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-beta3</version>
    </dependency>
</dependencies>
package co.chatterbox.xai.image;
import co.chatterbox.image_ablation.ModelOutput;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.List;
/**
 * {@link ModelOutput} backed by a prediction array and a parallel list of class labels.
 *
 * <p>The score for the class at position {@code i} of {@code classes} is read from
 * linear index {@code i} of {@code predictions}.
 */
public class KerasJavaOutput implements ModelOutput {
    private final INDArray predictions;
    private final List<String> classes;

    /**
     * @param classes     class labels, in the same order as the scores in {@code predictions}
     * @param predictions scores produced by the model, one per class
     */
    KerasJavaOutput(List<String> classes, INDArray predictions) {
        this.classes = classes;
        this.predictions = predictions;
    }

    /**
     * Returns the score predicted for the given class.
     *
     * @param classId label to look up; must be one of the labels passed to the constructor
     * @return the score stored at the label's position
     * @throws IllegalArgumentException if {@code classId} is not a known class
     */
    @Override
    public double getScoreForClass(String classId) {
        int pos = classes.indexOf(classId);
        if (pos < 0) {
            // indexOf returns -1 for unknown labels; fail with a clear message instead of
            // passing a negative index to the array.
            throw new IllegalArgumentException("Unknown class: " + classId);
        }
        return predictions.getDouble(pos);
    }

    /**
     * Returns the label with the highest predicted score.
     * On ties the label with the lowest index wins.
     *
     * @return the label at the position of the maximum score
     */
    @Override
    public String getPredictedClass() {
        int maxAt = 0;
        // Index 0 is the initial candidate, so the scan starts at 1.
        for (int i = 1; i < predictions.columns(); i++) {
            if (predictions.getDouble(i) > predictions.getDouble(maxAt)) {
                maxAt = i;
            }
        }
        return classes.get(maxAt);
    }
}
package co.chatterbox.xai.image;
import co.chatterbox.image_ablation.ModelOutput;
import co.chatterbox.image_ablation.PredictionInterface;
import co.chatterbox.xai.spi.ImagePredictor;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.transform.ColorConversionTransform;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.nd4j.linalg.api.ndarray.INDArray;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;
import static org.bytedeco.javacpp.opencv_imgproc.COLOR_BGR2RGB;
/**
 * {@link PredictionInterface} implementation that runs images through a Keras model
 * imported into a Deeplearning4j {@link ComputationGraph}.
 */
public class KerasJavaPredictor implements PredictionInterface {
    // Input dimensions are hard coded; they must match what the imported Keras model expects.
    private static final int WIDTH = 224;
    private static final int HEIGHT = 224;
    private static final int CHANNELS = 3;

    private final List<String> classes;
    private final ComputationGraph model;
    // Loader configured with the fixed input dimensions and a BGR-to-RGB color conversion.
    private final NativeImageLoader imgLoader =
            new NativeImageLoader(HEIGHT, WIDTH, CHANNELS, new ColorConversionTransform(COLOR_BGR2RGB));

    /**
     * Creates a predictor by importing a Keras model and its weights.
     *
     * @param classes   class labels, in the order of the model's output scores
     * @param modelPath path of the Keras model file to import
     * @throws RuntimeException if the model cannot be imported
     */
    public KerasJavaPredictor(List<String> classes, String modelPath) {
        this.classes = classes;
        this.model = loadModel(modelPath);
    }

    /** Imports the Keras model at {@code modelPath} and sets its cache mode to DEVICE. */
    private static ComputationGraph loadModel(String modelPath) {
        try {
            ComputationGraph imported = KerasModelImport.importKerasModelAndWeights(modelPath, false);
            imported.setCacheMode(CacheMode.DEVICE);
            return imported;
        } catch (Exception e) {
            throw new RuntimeException("Failed to import Keras model from " + modelPath, e);
        }
    }

    /** Converts an image to the matrix form produced by the configured image loader. */
    private INDArray preProcess(BufferedImage image) {
        try {
            return imgLoader.asMatrix(image);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs the model on a single image.
     *
     * @param bufferedImage image to classify
     * @return the model output paired with this predictor's class labels
     */
    @Override
    public ModelOutput predict(BufferedImage bufferedImage) {
        INDArray image = preProcess(bufferedImage);
        // NOTE(review): the boolean is presumably the "train" flag (false = inference) - confirm.
        return new KerasJavaOutput(classes, model.outputSingle(false, image));
    }

    /**
     * Runs {@link #predict(BufferedImage)} on each image in turn.
     *
     * @param images images to classify
     * @return one output per input image, in the same order
     */
    @Override
    public List<ModelOutput> predictBatch(List<BufferedImage> images) {
        return images.stream().map(this::predict).collect(Collectors.toList());
    }

    /**
     * Manual smoke test: loads a model, classifies the same image three times and prints
     * the results. The paths are placeholders and must be edited before running.
     *
     * @throws IOException if a test image cannot be read
     */
    public static void main(String[] args) throws IOException {
        KerasJavaPredictor predictor =
                new KerasJavaPredictor(Arrays.asList("dogs", "cats"), "/.../cats_vs_dogs.hdf5");

        List<BufferedImage> images = new ArrayList<>();
        images.add(ImageIO.read(new File("/.../dogs_vs_cats/dog.jpg")));
        images.add(ImageIO.read(new File("/.../dogs_vs_cats/dog.jpg")));
        images.add(ImageIO.read(new File("/.../dogs_vs_cats/dog.jpg")));

        List<ModelOutput> res = predictor.predictBatch(images);
        System.out.println(res);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment