Skip to content

Instantly share code, notes, and snippets.

@itsuwari
Created September 4, 2024 13:45
Show Gist options
  • Save itsuwari/4ff80c549084310aedadd0d7b02e6777 to your computer and use it in GitHub Desktop.
Save itsuwari/4ff80c549084310aedadd0d7b02e6777 to your computer and use it in GitHub Desktop.
ORB-MODELS to ONNX
import torch
from orb_models.forcefield import pretrained
from orb_models.forcefield.base import AtomGraphs
# Load a pretrained ORB force field. orb_d3_v1 is used here; another constructor
# from `pretrained` can be swapped in (rename the .onnx output file to match).
model = pretrained.orb_d3_v1()

# Only the inner graph network is exported (per the original notes, this is the
# MoleculeGNS instance).
core_model = model.model


def _build_dummy_graph():
    """Return a small example AtomGraphs (one 3-atom H, O, H system) for export tracing.

    The per-node and per-edge ``feat`` entries are random placeholders drawn
    from the global torch RNG: node features first, then edge features.
    """
    atomic_numbers = torch.tensor([1, 8, 1])  # H, O, H
    positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

    # Four directed edges sender -> receiver between the three atoms.
    senders = torch.tensor([0, 1, 2, 0])
    receivers = torch.tensor([1, 2, 0, 2])
    # Edge vectors computed from positions: receiver position minus sender position.
    edge_vectors = positions[receivers] - positions[senders]

    # Unit cell: a cube of side 2, with a leading axis of size 1 (a single system).
    cell = (2.0 * torch.eye(3)).unsqueeze(0)

    node_features = {
        "atomic_numbers": atomic_numbers,
        # One-hot over class indices 0..117, converted to float32.
        "atomic_numbers_embedding": torch.nn.functional.one_hot(
            atomic_numbers, num_classes=118
        ).float(),
        "positions": positions,
        # NOTE(review): widths 256 (node) / 53 (edge) presumably match what the
        # pretrained model expects -- confirm against the model config.
        "feat": torch.randn(3, 256),
    }
    edge_features = {
        "vectors": edge_vectors,
        "feat": torch.randn(4, 53),
    }

    return AtomGraphs(
        senders=senders,
        receivers=receivers,
        n_node=torch.tensor([atomic_numbers.shape[0]]),
        n_edge=torch.tensor([senders.shape[0]]),
        node_features=node_features,
        edge_features=edge_features,
        system_features={"cell": cell},
    )


dummy_input = _build_dummy_graph()

# Put the core model in inference mode before exporting it.
core_model.eval()
# Wrapper for ONNX export. The export call passes a flat tuple of tensors,
# while the core model is called with a single AtomGraphs object, so the graph
# is rebuilt from its tensors inside forward().
class ORBCoreWrapper(torch.nn.Module):
    """Expose the ORB core model through a flat, tensor-only ``forward``.

    ``torch.onnx.export`` invokes the module with positional tensors, but the
    core model takes one ``AtomGraphs``. This adapter accepts the graph's
    tensors individually, reassembles the ``AtomGraphs`` and delegates to the
    wrapped model.
    """

    def __init__(self, core_model: torch.nn.Module) -> None:
        """Register ``core_model`` as a child module of this wrapper.

        Args:
            core_model: Module called with the reassembled ``AtomGraphs``.
        """
        super().__init__()
        self.core_model = core_model

    def forward(
        self,
        senders,
        receivers,
        n_node,
        n_edge,
        atomic_numbers,
        atomic_numbers_embedding,
        positions,
        node_feat,
        edge_vectors,
        edge_feat,
        cell,
    ):
        """Rebuild an ``AtomGraphs`` from its tensors and run the core model.

        Args:
            senders: Edge source node indices (``AtomGraphs.senders``).
            receivers: Edge target node indices (``AtomGraphs.receivers``).
            n_node: Node counts (``AtomGraphs.n_node``).
            n_edge: Edge counts (``AtomGraphs.n_edge``).
            atomic_numbers: Stored as ``node_features["atomic_numbers"]``.
            atomic_numbers_embedding: Stored as
                ``node_features["atomic_numbers_embedding"]``.
            positions: Stored as ``node_features["positions"]``.
            node_feat: Stored as ``node_features["feat"]``.
            edge_vectors: Stored as ``edge_features["vectors"]``.
            edge_feat: Stored as ``edge_features["feat"]``.
            cell: Stored as ``system_features["cell"]``.

        Returns:
            Whatever ``core_model`` returns for the rebuilt graph.
        """
        atom_graphs = AtomGraphs(
            senders=senders,
            receivers=receivers,
            n_node=n_node,
            n_edge=n_edge,
            node_features={
                "atomic_numbers": atomic_numbers,
                "atomic_numbers_embedding": atomic_numbers_embedding,
                "positions": positions,
                "feat": node_feat,
            },
            edge_features={
                "vectors": edge_vectors,
                "feat": edge_feat,
            },
            system_features={
                "cell": cell,
            },
        )
        return self.core_model(atom_graphs)
# Wrap the core model behind a flat, tensor-only signature and switch the
# wrapper (and, recursively, the core model) to inference mode.
wrapped_model = ORBCoreWrapper(core_model)
wrapped_model.eval()

# ONNX graph input names, in the same order as ORBCoreWrapper.forward's
# positional parameters and the example tensors below.
input_names = [
    "senders",
    "receivers",
    "n_node",
    "n_edge",
    "atomic_numbers",
    "atomic_numbers_embedding",
    "positions",
    "node_feat",
    "edge_vectors",
    "edge_feat",
    "cell",
]
# Names assigned to the graph outputs, in order.
output_names = ["updated_node_features", "updated_edge_features"]

# Example tensors for tracing, pulled out of the dummy AtomGraphs in the order
# ORBCoreWrapper.forward expects.
_node_feats = dummy_input.node_features
_edge_feats = dummy_input.edge_features
_export_args = (
    dummy_input.senders,
    dummy_input.receivers,
    dummy_input.n_node,
    dummy_input.n_edge,
    _node_feats["atomic_numbers"],
    _node_feats["atomic_numbers_embedding"],
    _node_feats["positions"],
    _node_feats["feat"],
    _edge_feats["vectors"],
    _edge_feats["feat"],
    dummy_input.system_features["cell"],
)

# Axis 0 of each per-edge / per-node input and output is exported as a named
# symbolic dimension. n_node, n_edge and cell are not listed, so they keep their
# traced shapes.
_symbolic_first_axis = [
    ("senders", "num_edges"),
    ("receivers", "num_edges"),
    ("atomic_numbers", "num_nodes"),
    ("atomic_numbers_embedding", "num_nodes"),
    ("positions", "num_nodes"),
    ("node_feat", "num_nodes"),
    ("edge_vectors", "num_edges"),
    ("edge_feat", "num_edges"),
    ("updated_node_features", "num_nodes"),
    ("updated_edge_features", "num_edges"),
]

torch.onnx.export(
    wrapped_model,
    _export_args,
    "orb_d3_v1.onnx",
    opset_version=16,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={name: {0: dim} for name, dim in _symbolic_first_axis},
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment