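# NOTE: the two quoted lines below appear to be web-page text captured when
# this snippet was copied from a GitHub Gist page; they are not part of the code.
# "Skip to content"
# "Instantly share code, notes, and snippets."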

# Transitive closure of the strictly upper-triangular part of a batch of
# adjacency matrices.
#
# `adj` must be a 3-D tensor (torch.bmm only accepts (batch, n, m) inputs),
# square in its last two dims so each matrix can be multiplied by itself.
# NOTE(review): assumes `adj` holds 0/1 (or at least non-negative) entries;
# torch.clamp(max=1) only caps values from above, so negative entries would
# break the reachability interpretation -- confirm against the producer of adj.
#
# Names bound here:
#   new_adj    -- torch.triu(adj, diagonal=1): only edges i -> j with i < j.
#   new_adj_or -- 0/1 reachability: [b, i, j] == 1 iff j is reachable from i
#                 along one or more edges of new_adj.
#   new_adj1   -- path-count product bmm(new_adj_or, new_adj_or) of the final
#                 closure (the loop only exits once this equality holds).
#   loop       -- set to 1; not updated in this block.
new_adj = torch.triu(adj, diagonal=1)
new_adj1 = torch.bmm(new_adj, new_adj)
# Reachable in one or two steps; the clamp turns path counts into 0/1 flags.
new_adj_or = torch.clamp(new_adj + new_adj1, max=1)
# NOTE(review): `loop` looks like an iteration counter but is never
# incremented here; kept at 1 in case code elsewhere reads it.
loop = 1
# Repeated squaring: each pass adds every path made of two already-known hops,
# so the covered path length roughly doubles and the loop finishes after about
# log2(n) passes. It stops at a fixed point: once the square of the closure
# equals the product that was already folded into it, adding it again would
# change nothing.
while True:
    # Compute the square once per pass and reuse it for both the convergence
    # test and the update.
    new_adj_sq = torch.bmm(new_adj_or, new_adj_or)
    if torch.equal(new_adj_sq, new_adj1):
        break
    new_adj1 = new_adj_sq
    new_adj_or = torch.clamp(new_adj_or + new_adj1, max=1)