Created
January 14, 2019 23:48
-
-
Save zhreshold/9c28c176908c4289dfa5572140ed2dd6 to your computer and use it in GitHub Desktop.
Demo script: export a Gluon network to a symbol file and load it back with SymbolBlock (Gluon network -> symbol -> SymbolBlock)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""SSD Demo script.""" | |
import os | |
import argparse | |
import mxnet as mx | |
import gluoncv as gcv | |
from gluoncv.data.transforms import presets | |
from matplotlib import pyplot as plt | |
def parse_args(argv=None):
    """Parse the command-line options of the SSD demo.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When ``None`` (the default) argparse
        reads them from ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Namespace with the string attributes ``network``, ``images``,
        ``gpus`` and ``pretrained``.
    """
    parser = argparse.ArgumentParser(description='Test with SSD networks.')
    parser.add_argument('--network', type=str, default='ssd_300_vgg16_atrous_voc',
                        help="Base network name")
    parser.add_argument('--images', type=str, default='',
                        help='Test images, use comma to split multiple.')
    # This script only runs inference, so the help text must not mention training.
    parser.add_argument('--gpus', type=str, default='',
                        help='GPUs to run inference on, you can specify 1,3 for example. '
                             'Leave empty to use the CPU.')
    # Kept as a string on purpose: the caller treats it either as a boolean-like
    # flag ('true', '1', 'yes', 't') or as a path to a saved parameter file.
    parser.add_argument('--pretrained', type=str, default='True',
                        help='Load weights from previously saved parameters.')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Demo flow: build an SSD network, export it to symbol JSON + params,
    # load it back as a SymbolBlock, then run detection on each test image.
    args = parse_args()

    # Context list: one GPU context per id listed in --gpus, CPU when none given.
    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
    ctx = ctx if ctx else [mx.cpu()]

    # Grab a sample image if none was specified.
    if not args.images.strip():
        gcv.utils.download("https://cloud.githubusercontent.com/assets/3307514/" +
                           "20012568/cbc2d6f6-a27d-11e6-94c3-d35a9cb47609.jpg", 'street.jpg')
        image_list = ['street.jpg']
    else:
        image_list = [x.strip() for x in args.images.split(',') if x.strip()]

    # --pretrained is either a boolean-like flag (use model-zoo weights) or,
    # otherwise, a path to a parameter file to load into a bare network.
    if args.pretrained.lower() in ['true', '1', 'yes', 't']:
        net = gcv.model_zoo.get_model(args.network, pretrained=True)
    else:
        net = gcv.model_zoo.get_model(args.network, pretrained=False, pretrained_base=False)
        net.load_parameters(args.pretrained)
    # NOTE(review): presumably NMS threshold 0.45 and top-k 200 -- confirm
    # against the set_nms signature of the selected model.
    net.set_nms(0.45, 200)

    # Export to JSON and load back with SymbolBlock.
    # Keep the class names first: the re-imported `net` replaces the Gluon
    # model, and they are needed later for plotting.
    class_names = net.classes
    print('Export to JSON...')
    gcv.utils.export_block(args.network, net, preprocess=False, layout='CHW')
    print('Load back from JSON with SymbolBlock')
    net = mx.gluon.SymbolBlock.imports('{}-symbol.json'.format(args.network),
                                       ['data'], '{}-0000.params'.format(args.network))
    net.collect_params().reset_ctx(ctx=ctx)

    for image in image_list:
        x, img = presets.ssd.load_test(image, short=512)
        # Only the first context is used for inference.
        x = x.as_in_context(ctx[0])
        # Take the first (only) sample of each network output as a NumPy array.
        ids, scores, bboxes = [xx[0].asnumpy() for xx in net(x)]
        # Let plot_bbox create its own axes for every image instead of reusing
        # the previous one: once a blocking plt.show() returns, the old figure
        # is closed and anything drawn on its axes would not be displayed.
        gcv.utils.viz.plot_bbox(img, bboxes, scores, ids,
                                class_names=class_names)
        plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment