Last active
January 12, 2017 18:19
-
-
Save bigsnarfdude/6f9e1296862f9190eb9045c949441a9b to your computer and use it in GitHub Desktop.
Flask classification proof-of-concept (PoC) service
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0rc0-cp34-cp34m-linux_x86_64.whl | |
""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import shutil | |
from sklearn import datasets, metrics, cross_validation | |
from tensorflow.contrib import skflow | |
import numpy as np | |
from flask import Flask, abort, jsonify, request | |
#import cPickle as pickle | |
# basic toy nn pickled loader | |
#pkl_file = open('batch_model_2016_04_23.pkl', 'rb') | |
#latest_neural_network = pickle.load(pkl_file) | |
#pkl_file.close() | |
from sklearn import datasets, metrics, cross_validation | |
from tensorflow.contrib import skflow | |
import os

# Reproducible 80/20 split of the iris dataset (random_state=42). Only the test
# portion is used below, as a startup sanity check of the restored model;
# X_train / y_train are kept as module-level names for compatibility.
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = cross_validation.train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42)

# Directory of the skflow model trained offline in the 2016-04-23 batch run
# (it was written with classifier.save(<this path>)). Set SKFLOW_MODEL_PATH to
# load the model from a different location; the default is the original path.
MODEL_PATH = os.environ.get(
    'SKFLOW_MODEL_PATH',
    '/home/ubuntu/scratch/skflow_batch/batch_model_2016_04_23')

# Restore the trained estimator (parameters and learned variables);
# `new_classifier` is the model used by the /api request handler.
new_classifier = skflow.TensorFlowEstimator.restore(MODEL_PATH)

# Smoke test: print accuracy on the test split to confirm the model loaded.
score = metrics.accuracy_score(y_test, new_classifier.predict(X_test))
print('Accuracy: {0:f}'.format(score))
app = Flask(__name__)

# JSON keys expected in the request body, in the column order of the iris
# feature matrix (sepal length, sepal width, petal length, petal width).
FEATURE_KEYS = ('sl', 'sw', 'pl', 'pw')


@app.route('/api', methods=['POST'])
def make_predict():
    """Classify a single iris sample posted as JSON.

    Expects a JSON object such as {"sl": 5.1, "sw": 3.5, "pl": 1.4, "pw": 0.2}.
    The body is parsed as JSON even without a JSON Content-Type header
    (force=True).

    Returns:
        A JSON response {"results": "<class>"}, where <class> is the first
        element of the model's prediction converted to a string.

    Aborts with 400 Bad Request if the body is not a JSON object, a feature
    key is missing, or a feature value cannot be converted to a float.
    Without this check, bad input surfaced as a 500 error (or as the
    interactive debugger when running in debug mode).
    """
    data = request.get_json(force=True)
    if not isinstance(data, dict):
        abort(400)
    try:
        # float() also accepts numeric strings such as "5.1". Without the
        # conversion, such strings would yield a string array the model
        # cannot use.
        features = [float(data[key]) for key in FEATURE_KEYS]
    except (KeyError, TypeError, ValueError):
        abort(400)

    # predict() takes a batch of samples, so wrap the one sample: shape (1, 4).
    predict_request = np.array([features])
    # Model trained per https://gist.github.com/bigsnarfdude/57ff7d6095f7ee83d4195d1fed26388b
    y_hat = new_classifier.predict(predict_request)
    # str() keeps the response JSON-serializable (numpy scalars are not).
    return jsonify(results=str(y_hat[0]))
if __name__ == '__main__':
    import os

    # The defaults reproduce the original hard-coded settings (port 11111,
    # debug on). Override them with the PORT and FLASK_DEBUG environment
    # variables; set FLASK_DEBUG=0 to turn debug mode off.
    #
    # WARNING: debug mode enables the Werkzeug interactive debugger, which can
    # execute arbitrary code. Never enable it where untrusted clients can reach
    # the server. Debug mode also starts the reloader, which re-runs this
    # module in a child process, so the model is restored (and the accuracy
    # printed) twice at startup.
    port = int(os.environ.get('PORT', '11111'))
    debug = os.environ.get('FLASK_DEBUG', '1').lower() not in ('0', 'false', 'no', '')
    app.run(port=port, debug=debug)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment