@groverpr
Created April 13, 2020 04:19
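import mxnet as mx
from mxnet import gluon


def load_base_model(model_path, epoch, ctx, layer_name=None, n_inputs=2):
    """Load an exported model from `model_path` and return a sub-network
    whose output is the internal layer `layer_name`.

    Expects the files written by `HybridBlock.export`:
    `<model_path>-symbol.json` and `<model_path>-<epoch:04d>.params`.
    """
    input_names = ['data%i' % i for i in range(n_inputs)]
    net = gluon.nn.SymbolBlock.imports(
        model_path + "-symbol.json",
        input_names,
        model_path + "-%04d.params" % epoch,
        ctx=ctx,
    )
    if layer_name is None:
        # No layer requested: return the full network unchanged.
        return net
    # Re-trace the network symbolically so its internal outputs are reachable.
    inputs = [mx.sym.var(name) for name in input_names]
    output = net(*inputs)
    # Internal output names usually end in "_output", e.g. "dense0_output".
    outputs = output.get_internals()[layer_name]
    # Share the already-loaded parameters with the truncated sub-network.
    return gluon.nn.SymbolBlock(outputs, inputs, params=net.collect_params())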
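load_base_model reads two files whose names follow MXNet's export convention: `<prefix>-symbol.json` for the graph and `<prefix>-<epoch as 4 digits>.params` for the weights. A missing or misnamed file only surfaces as an error inside `SymbolBlock.imports`, so it can help to check the paths first. The following is a minimal pure-Python sketch of that check; the helper names `checkpoint_files` and `missing_checkpoint_files` are hypothetical and not part of MXNet.

```python
import os


def checkpoint_files(model_path, epoch):
    """Return the (symbol, params) file names that load_base_model reads.

    Hypothetical helper mirroring the naming used by HybridBlock.export:
    "<prefix>-symbol.json" and "<prefix>-<epoch zero-padded to 4 digits>.params".
    """
    symbol_file = model_path + "-symbol.json"
    params_file = model_path + "-%04d.params" % epoch
    return symbol_file, params_file


def missing_checkpoint_files(model_path, epoch):
    """Hypothetical helper: list whichever expected files are not on disk."""
    return [f for f in checkpoint_files(model_path, epoch) if not os.path.exists(f)]


if __name__ == "__main__":
    print(checkpoint_files("models/base", 5))
    # ('models/base-symbol.json', 'models/base-0005.params')
```

An empty list from `missing_checkpoint_files(model_path, epoch)` means both files are present and `load_base_model` can be called with the same arguments.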