@petewarden
Created January 20, 2021 22:16
lstm_quantization.py
# ==============================================================================
"""LSTM quantization with python."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pathlib
from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
flags.DEFINE_integer('train_steps', 2, 'Number of steps in training.')
flags.DEFINE_string('tflite_dir', '/tmp/lstm/tflite',
                    'Directory to save/restore float tflite model.')
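# Typical invocation (assuming this file is saved as lstm_quantization.py):
#   python lstm_quantization.py --train_steps=2 --tflite_dir=/tmp/lstm/tflite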
def load_data(training_data_points, test_data_points):
  """Load MNIST data, downsample, and transform."""
  tf.print('Loading data...\n')
  # Load the MNIST dataset.
  mnist = tf.keras.datasets.mnist
  (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
  # Downsampling.
  train_images = train_images[0:training_data_points]
  train_labels = train_labels[0:training_data_points]
  test_images = test_images[0:test_data_points]
  test_labels = test_labels[0:test_data_points]
  # Normalize the input images so that each pixel value is between 0 and 1.
  train_images = train_images.astype(np.float32) / 255.0
  test_images = test_images.astype(np.float32) / 255.0
  # Return.
  return (train_images, train_labels, test_images, test_labels)
def train(model, train_images, train_labels, test_images, test_labels, steps):
  """Train the model."""
  tf.print('Training model...\n')
  # The default batch size is 32, so 960 training points give 30 iterations per epoch.
  model.fit(
      train_images,
      train_labels,
      epochs=steps,
      validation_data=(test_images, test_labels))
def build_model():
  """Build LSTM model."""
  tf.print('Building LSTM model.\n')
  model = tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(28, 28), name='input'),
      tf.keras.layers.LSTM(20, time_major=False, return_sequences=True),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(10, activation=tf.nn.softmax, name='output')
  ])
  model.compile(
      optimizer='adam',
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'])
  model.summary()
  return model
def generate_data():
  """Generator for calibration data."""
  tf.print('Generating calibration data...\n')
  mnist = tf.keras.datasets.mnist
  (images, _), (_, _) = mnist.load_data()
  images = images[0:64]
  images = images.astype(np.float32) / 255.0
  for image in images:
    # Reshape [28, 28] to [1, 28, 28] for tflite.
    image = np.expand_dims(image, axis=0)
    yield [image]
def convert_and_quantize_model(model):
  """Convert and quantize the LSTM model."""
  tf.print('Quantizing fused LSTM model...\n')
  run_model = tf.function(lambda x: model(x))
  # Fix the input shape: [batch_size, time steps, input size].
  batch_size = 1
  steps = 28
  input_size = 28
  concrete_func = run_model.get_concrete_function(
      tf.TensorSpec([batch_size, steps, input_size], model.inputs[0].dtype))
  # Save to the model directory.
  model_dir = '/tmp/keras_lstm'
  model.save(model_dir, save_format='tf', signatures=concrete_func)
  # Quantize from the saved model.
  converter = tf.lite.TFLiteConverter.from_saved_model(model_dir)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  converter.representative_dataset = generate_data
  tflite_model = converter.convert()
  return tflite_model
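# Note (not part of the original gist): with Optimize.DEFAULT plus a
# representative dataset, the converter quantizes weights and activations but
# keeps float32 input/output tensors. If a fully integer model is wanted, the
# standard converter settings are sketched below; whether the fused LSTM kernel
# supports them may depend on the TensorFlow version.
#   converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
#   converter.inference_input_type = tf.int8
#   converter.inference_output_type = tf.int8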
def tflite_float_inference(model, images, expected_labels):
  """Run the model."""
  tf.print('Running tflite_float_inference...\n')
  # Initialize the interpreter.
  interpreter = tf.lite.Interpreter(model_content=model)
  interpreter.allocate_tensors()
  # Interpreter details.
  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]
  print('input_details:')
  print(input_details)
  print('output_details:')
  print(output_details)
  # Add a batch dimension to the first test image.
  image = np.expand_dims(images[0], axis=0).astype(input_details['dtype'])
  expected_label = expected_labels[0]
  interpreter.set_tensor(input_details['index'], image)
  interpreter.invoke()
  output = interpreter.get_tensor(output_details['index'])[0]  # [0] for batch.
  prediction = output.argmax()
  print('output:')
  print(interpreter.get_tensor(output_details['index']))
  print('Expected', expected_label, ' and predicted', prediction, '\n')
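# Optional sanity check (not part of the original gist): a minimal sketch that
# dumps every tensor's name and dtype via the standard
# tf.lite.Interpreter.get_tensor_details() call, to confirm which weights and
# activations ended up quantized to int8. It could be called from main(), e.g.
# print_tensor_dtypes(tflite_model_quantized_fused).
def print_tensor_dtypes(model):
  """Print the name and dtype of every tensor in a TFLite flatbuffer."""
  interpreter = tf.lite.Interpreter(model_content=model)
  interpreter.allocate_tensors()
  for detail in interpreter.get_tensor_details():
    print(detail['name'], detail['dtype'])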
def save_tflite_files(model, path, name):
  """Save the TFLite model."""
  tf.print('Saving tflite model', path, '/', name, '...\n')
  tflite_models_dir = pathlib.Path(path)
  tflite_models_dir.mkdir(exist_ok=True, parents=True)
  tflite_model_file = tflite_models_dir / name
  tflite_model_file.write_bytes(model)
def main(_):
  # Load data.
  (train_images, train_labels, test_images, test_labels) = load_data(960, 320)
  # Build model.
  model = build_model()
  # Train model.
  train(model, train_images, train_labels, test_images, test_labels,
        flags.FLAGS.train_steps)
  # Convert model.
  tflite_model_quantized_fused = convert_and_quantize_model(model)
  # Run tflite inference.
  tflite_float_inference(tflite_model_quantized_fused, test_images, test_labels)
  # Save tflite model.
  location = flags.FLAGS.tflite_dir
  save_tflite_files(tflite_model_quantized_fused, location, 'lstm_quant.tflite')


if __name__ == '__main__':
  app.run(main)