Skip to content

Instantly share code, notes, and snippets.

@Krucamper
Created November 23, 2019 18:12
Show Gist options
  • Save Krucamper/99c0f6ae15c75fec9ecb770acf14ecc6 to your computer and use it in GitHub Desktop.
Save Krucamper/99c0f6ae15c75fec9ecb770acf14ecc6 to your computer and use it in GitHub Desktop.
CNN กับ Marvel Cinematic Universe (test model)
from PIL import Image
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import model
import os
# Class labels, listed in the index order of the classifier's softmax output.
# NOTE(review): assumes training assigned label indices in this
# (alphabetical) order -- confirm against the training input pipeline.
# This must be a list, not a set: sets cannot be indexed, and their iteration
# order changes with string hash randomization, so a set cannot map a
# prediction index back to a stable label.
classes = ['BlackPanther', 'DoctorStrange', 'IronMan', 'ScarletWitch', 'SpiderMan', 'Thor']
# Root directory with one sub-directory of images per class name.
train_dir = 'raw'
# Directory from which the latest training checkpoint is restored.
logs_train_dir = './train_logs'
# Evaluation feeds a single image per forward pass.
BATCH_SIZE = 1
# Number of output classes; derived from `classes` so the two cannot drift apart.
CLASSES = len(classes)
def get_random_dir(classes, train_dir):
    """Return the path of a randomly chosen class directory.

    Args:
        classes: Collection of class (sub-directory) names. A list or tuple
            is indexed in its given order. Any other collection (e.g. a set)
            is sorted first, so that a seeded ``np.random`` state picks the
            same class on every run.
        train_dir: Directory that contains one sub-directory per class.

    Returns:
        ``os.path.join(train_dir, <chosen class name>)``.

    Raises:
        ValueError: If ``classes`` is empty.
    """
    # Sets are not subscriptable and have no stable order, so normalize first.
    pool = classes if isinstance(classes, (list, tuple)) else sorted(classes)
    if not pool:
        raise ValueError('classes must contain at least one class name')
    class_name = pool[np.random.randint(0, len(pool))]
    return os.path.join(train_dir, class_name)
def get_random_image(path_random_dir):
    """Return the path of a randomly chosen file inside a directory.

    The directory listing is sorted, because ``os.listdir`` order is
    arbitrary and sorting lets a seeded ``np.random`` state select the same
    file across runs. Sub-directories are skipped because they cannot be
    opened as images.

    Args:
        path_random_dir: Directory to pick from, e.g. a path produced by
            ``get_random_dir``.

    Returns:
        ``os.path.join(path_random_dir, <chosen file name>)``.

    Raises:
        ValueError: If the directory contains no regular files.
    """
    file_names = [
        name for name in sorted(os.listdir(path_random_dir))
        if os.path.isfile(os.path.join(path_random_dir, name))
    ]
    if not file_names:
        raise ValueError('no files found in %r' % (path_random_dir,))
    image_name = file_names[np.random.randint(0, len(file_names))]
    return os.path.join(path_random_dir, image_name)
def get_one_image(image, size=None):
    """Load an image file, display it, and return its pixels as an array.

    Args:
        image: Path (or file object) accepted by ``PIL.Image.open``.
        size: Optional ``(width, height)``. When given, the image is resized
            after being displayed, e.g. ``(64, 64)`` to match the
            ``[1, 64, 64, 3]`` reshape done in ``evaluate_image``. The
            default ``None`` keeps the original size.

    Returns:
        ``numpy.ndarray`` holding the (optionally resized) pixel data.
    """
    # The context manager closes the file handle PIL keeps open. np.array()
    # forces PIL to decode the pixels before the file is closed.
    with Image.open(image) as img:
        plt.imshow(img)
        plt.show()
        if size is not None:
            img = img.resize(tuple(size))
        return np.array(img)
def evaluate_image(image_array):
    """Classify one image with the latest checkpoint in ``logs_train_dir``.

    Builds a fresh graph (``model.inference`` followed by softmax), restores
    the most recent checkpoint, and runs a single forward pass.

    Args:
        image_array: Pixel array matching the input placeholder shape
            ``(64, 64, 3)``. It is fed as float32.

    Returns:
        ``str(<predicted class name>) + str(prediction[:, <predicted index>])``,
        i.e. the label followed by the printed softmax probability slice.

    Raises:
        FileNotFoundError: If no checkpoint is found in ``logs_train_dir``.
    """
    # Index -> label table. ``classes`` may be a set, which is neither
    # indexable nor ordered; sorting yields the alphabetical order in which
    # the labels are listed.
    labels = classes if isinstance(classes, (list, tuple)) else sorted(classes)

    with tf.Graph().as_default():
        # Build the network on top of the placeholder so the value passed in
        # feed_dict is what the model actually sees.
        x = tf.compat.v1.placeholder(tf.float32, shape=[64, 64, 3])
        image = tf.image.per_image_standardization(x)
        image = tf.reshape(image, [1, 64, 64, 3])
        logits = model.inference(image, BATCH_SIZE, CLASSES)
        probabilities = tf.nn.softmax(logits)
        saver = tf.compat.v1.train.Saver()
        with tf.compat.v1.Session() as sess:
            print("Reading checkpoints...")
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                # Running without restored weights would only fail later with
                # an uninitialized-variable error, so fail fast and clearly.
                raise FileNotFoundError('No checkpoint found in %r' % (logs_train_dir,))
            # Checkpoint files are named "<prefix>-<step>".
            global_step = os.path.basename(ckpt.model_checkpoint_path).split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading..., step is %s' % global_step)
            prediction = sess.run(probabilities, feed_dict={x: image_array})
            max_index = int(np.argmax(prediction))
            return str(labels[max_index]) + str(prediction[:, max_index])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment