@sparticlesteve
Created September 3, 2020 18:37
Modified graph conv LSTM example showing graph sequence data
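This script builds sequences of random Watts-Strogatz graphs with networkx, feeds each sequence through two stacked GConvLSTM layers from torch_geometric_temporal, and trains a per-node classifier against randomly generated mock targets.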
import torch
import random
import numpy as np
import networkx as nx
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import GConvLSTM
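
# Dependencies: torch, numpy, networkx, and torch-geometric-temporal
# (assumed installable via `pip install torch-geometric-temporal`)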

def create_mock_data(number_of_nodes, edge_per_node, in_channels):
    """
    Creating a mock feature matrix and edge index.
    """
    graph = nx.watts_strogatz_graph(number_of_nodes, edge_per_node, 0.5)
    edge_index = torch.LongTensor(np.array([edge for edge in graph.edges()]).T)
    X = torch.FloatTensor(np.random.uniform(-1, 1, (number_of_nodes, in_channels)))
    return X, edge_index


def create_mock_edge_weight(edge_index):
    """
    Creating a mock edge weight tensor.
    """
    return torch.FloatTensor(np.random.uniform(0, 1, (edge_index.shape[1])))


def create_mock_target(number_of_nodes, number_of_classes):
    """
    Creating a mock target vector.
    """
    return torch.LongTensor([random.randint(0, number_of_classes - 1) for node in range(number_of_nodes)])

class RecurrentGCN(torch.nn.Module):

    def __init__(self, node_features, num_classes):
        super(RecurrentGCN, self).__init__()
        # Documentation for GConvLSTM:
        # https://pytorch-geometric-temporal.readthedocs.io/en/latest/modules/root.html#torch_geometric_temporal.nn.recurrent.gconv_lstm.GConvLSTM
        # Arguments are (in_channels, out_channels, K), with K the Chebyshev filter size
        self.recurrent_1 = GConvLSTM(node_features, 32, 5)
        self.recurrent_2 = GConvLSTM(32, 16, 5)
        self.linear = torch.nn.Linear(16, num_classes)

    def forward(self, graphs):
        # Process the sequence of graphs with our two GConvLSTM layers.
        # Hidden and cell states start as None so the GConvLSTM layers
        # initialize them automatically on the first step.
        h1, c1, h2, c2 = None, None, None, None
        for x, edge_index, edge_weight in graphs:
            h1, c1 = self.recurrent_1(x, edge_index, edge_weight, H=h1, C=c1)
            # Feed the hidden state of the first layer into the second layer
            h2, c2 = self.recurrent_2(h1, edge_index, edge_weight, H=h2, C=c2)
        # Use the final hidden state of the second recurrent layer as input to the classifier
        x = F.relu(h2)
        x = F.dropout(x, training=self.training)
        x = self.linear(x)
        return F.log_softmax(x, dim=1)

# Hyperparameters
node_features = 100
node_count = 1000
num_classes = 10
sequence_len = 4
edge_per_node = 15
epochs = 200
learning_rate = 0.01
weight_decay = 5e-4

print('Building model')
model = RecurrentGCN(node_features=node_features, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

model.train()
for epoch in range(epochs):
    print('Step', epoch)
    optimizer.zero_grad()
    # Create a sequence of mock graphs
    graphs = []
    for i in range(sequence_len):
        x, edge_index = create_mock_data(node_count, edge_per_node, node_features)
        edge_weight = create_mock_edge_weight(edge_index)
        graphs.append((x, edge_index, edge_weight))
    # Create a mock target
    target = create_mock_target(node_count, num_classes)
    # Apply the model to the graph sequence
    scores = model(graphs)
    # Compute the loss and take an optimizer step
    loss = F.nll_loss(scores, target)
    loss.backward()
    optimizer.step()
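
A minimal evaluation sketch (not part of the original gist): it reuses the helper functions above to generate one held-out mock sequence and reports per-node accuracy. Because the targets are random, accuracy should hover near 1 / num_classes.

model.eval()
with torch.no_grad():
    # Build one fresh mock sequence, just like in the training loop
    graphs = []
    for i in range(sequence_len):
        x, edge_index = create_mock_data(node_count, edge_per_node, node_features)
        edge_weight = create_mock_edge_weight(edge_index)
        graphs.append((x, edge_index, edge_weight))
    target = create_mock_target(node_count, num_classes)
    scores = model(graphs)
    # Predicted class per node, compared against the random mock targets
    accuracy = (scores.argmax(dim=1) == target).float().mean().item()
    print('Mock accuracy:', accuracy)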