@jackmead515
Created August 18, 2024 02:17
Convert To Detectron2 Compatibility
import torch
import re

# load the original PhenoBench Mask R-CNN checkpoint (torchvision key names) on CPU
phmodel = torch.load("phenobench_mask_rcnn_r50_fpn.pt", map_location=torch.device('cpu'))
oldmodel = phmodel['model_state_dict']

newmodel = {}

# remap the ResNet-50 backbone keys from torchvision naming to detectron2 naming
for k in list(oldmodel.keys()):
    old_k = k
    # stem layers (conv1, bn1) carry no "layer" prefix in torchvision
    if "layer" not in k:
        k = "stem." + k
    # torchvision layer1..layer4 map to detectron2 res2..res5
    for t in [1, 2, 3, 4]:
        k = k.replace("layer{}".format(t), "res{}".format(t + 1))
    # batch norm layers bn1..bn3 become conv1.norm..conv3.norm
    for t in [1, 2, 3]:
        k = k.replace("bn{}".format(t), "conv{}.norm".format(t))
    k = k.replace("downsample.0", "shortcut")
    k = k.replace("downsample.1", "shortcut.norm")
    newmodel[k] = oldmodel.pop(old_k).detach().numpy()

# strip a leading 'stem.network.' or 'network.' prefix from every key
for key in list(newmodel.keys()):
    new_key = re.sub(r'^stem\.network\.', '', key)
    new_key = re.sub(r'^network\.', '', new_key)
    newmodel[new_key] = newmodel.pop(key)

# rename the 'stem' layers
for key in list(newmodel.keys()):
    if key.startswith('backbone.body.conv1'):
        new_key = key.replace('backbone.body.conv1', 'backbone.bottom_up.stem.conv1')
        newmodel[new_key] = newmodel.pop(key)

# rename the res layers
for key in list(newmodel.keys()):
    if key.startswith('backbone.body.res'):
        new_key = key.replace('backbone.body.res', 'backbone.bottom_up.res')
        newmodel[new_key] = newmodel.pop(key)
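
# Sanity check (an added sketch, not part of the original gist): any key still
# carrying the torchvision 'backbone.body.' prefix was missed by the renames above.
unmapped = [name for name in newmodel if name.startswith('backbone.body.')]
if unmapped:
    print('backbone keys left unmapped:', unmapped)

# the FPN and ROI-head keys are remapped 1:1 using the two lists below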
originals = """
backbone.fpn.inner_blocks.0.0.bias
backbone.fpn.inner_blocks.0.0.weight
backbone.fpn.inner_blocks.1.0.bias
backbone.fpn.inner_blocks.1.0.weight
backbone.fpn.inner_blocks.2.0.bias
backbone.fpn.inner_blocks.2.0.weight
backbone.fpn.inner_blocks.3.0.bias
backbone.fpn.inner_blocks.3.0.weight
backbone.fpn.layer_blocks.0.0.bias
backbone.fpn.layer_blocks.0.0.weight
backbone.fpn.layer_blocks.1.0.bias
backbone.fpn.layer_blocks.1.0.weight
backbone.fpn.layer_blocks.2.0.bias
backbone.fpn.layer_blocks.2.0.weight
backbone.fpn.layer_blocks.3.0.bias
backbone.fpn.layer_blocks.3.0.weight
roi_heads.box_head.fc6.bias
roi_heads.box_head.fc6.weight
roi_heads.box_head.fc7.bias
roi_heads.box_head.fc7.weight
roi_heads.box_predictor.bbox_pred.bias
roi_heads.box_predictor.bbox_pred.weight
roi_heads.box_predictor.cls_score.bias
roi_heads.box_predictor.cls_score.weight
roi_heads.mask_head.0.0.bias
roi_heads.mask_head.0.0.weight
roi_heads.mask_head.1.0.bias
roi_heads.mask_head.1.0.weight
roi_heads.mask_head.2.0.bias
roi_heads.mask_head.2.0.weight
roi_heads.mask_head.3.0.bias
roi_heads.mask_head.3.0.weight
roi_heads.mask_predictor.conv5_mask.bias
roi_heads.mask_predictor.conv5_mask.weight
roi_heads.mask_predictor.mask_fcn_logits.bias
roi_heads.mask_predictor.mask_fcn_logits.weight
""".split('\n')[1:-1]
replacements = """
backbone.fpn_lateral2.bias
backbone.fpn_lateral2.weight
backbone.fpn_lateral3.bias
backbone.fpn_lateral3.weight
backbone.fpn_lateral4.bias
backbone.fpn_lateral4.weight
backbone.fpn_lateral5.bias
backbone.fpn_lateral5.weight
backbone.fpn_output2.bias
backbone.fpn_output2.weight
backbone.fpn_output3.bias
backbone.fpn_output3.weight
backbone.fpn_output4.bias
backbone.fpn_output4.weight
backbone.fpn_output5.bias
backbone.fpn_output5.weight
roi_heads.box_head.fc1.bias
roi_heads.box_head.fc1.weight
roi_heads.box_head.fc2.bias
roi_heads.box_head.fc2.weight
roi_heads.box_predictor.bbox_pred.bias
roi_heads.box_predictor.bbox_pred.weight
roi_heads.box_predictor.cls_score.bias
roi_heads.box_predictor.cls_score.weight
roi_heads.mask_head.mask_fcn1.bias
roi_heads.mask_head.mask_fcn1.weight
roi_heads.mask_head.mask_fcn2.bias
roi_heads.mask_head.mask_fcn2.weight
roi_heads.mask_head.mask_fcn3.bias
roi_heads.mask_head.mask_fcn3.weight
roi_heads.mask_head.mask_fcn4.bias
roi_heads.mask_head.mask_fcn4.weight
roi_heads.mask_head.deconv.bias
roi_heads.mask_head.deconv.weight
roi_heads.mask_head.predictor.bias
roi_heads.mask_head.predictor.weight
""".split('\n')[1:-1]
for o, r in zip(originals, replacements):
    newmodel[r] = newmodel.pop(o)
torch.save({ 'model': newmodel }, "phenobench_mask_rcnn_r50_fpn_fixed.pth")
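
# Optional verification (a sketch added here, not part of the original gist):
# build a standard detectron2 Mask R-CNN R50-FPN model and load the converted
# checkpoint with DetectionCheckpointer, which reports missing or unexpected
# keys. The COCO config file and NUM_CLASSES below are assumptions; set
# NUM_CLASSES to match the cls_score shape stored in the converted weights.
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2  # assumption: set to the PhenoBench class count
cfg.MODEL.WEIGHTS = "phenobench_mask_rcnn_r50_fpn_fixed.pth"
cfg.MODEL.DEVICE = "cpu"

model = build_model(cfg)
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)  # logs any key or shape mismatches
model.eval()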