import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import remove_self_loops, add_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter_max, scatter_add


def softmax(src, index, num_nodes=None):
    r"""Computes a sparsely evaluated softmax over batched inputs.

    Given a value tensor :attr:`src` of shape
    :obj:`(batch_size, num_edges, ...)`, this function first groups the
    values along the edge dimension (dim :obj:`1`) based on the indices
    specified in :attr:`index`, and then proceeds to compute the softmax
    individually for each group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    num_nodes = maybe_num_nodes(index, num_nodes)

    # Subtract the per-group maximum before exponentiating, for numerical
    # stability; the gather `[:, index]` broadcasts it back onto every edge.
    out = src - scatter_max(src, index, dim=1, dim_size=num_nodes)[0][:, index]
    out = out.exp()
    out = out / (
        scatter_add(out, index, dim=1, dim_size=num_nodes)[:, index] + 1e-16)
    return out
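
# Usage sketch (an illustration, not part of the original gist):
#
#     src = torch.randn(4, 10, 2)               # (batch_size, E, heads) scores
#     index = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])  # target of each edge
#     att = softmax(src, index, num_nodes=5)    # same shape as `src`
#
# The result sums to 1 over all edges that share a target node, independently
# for each batch element and each head.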


class GATConv(MessagePassing):
    r"""The graph attentional operator from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
        \right)\right)}.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_channels, out_channels, heads=1, concat=True,
                 negative_slope=0.2, dropout=0, bias=True, **kwargs):
        # Node features are batched as (batch_size, num_nodes, channels), so
        # messages are gathered and aggregated along dim 1 rather than dim 0.
        kwargs.setdefault('node_dim', 1)
        super().__init__(aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.weight = Parameter(
            torch.Tensor(in_channels, heads * out_channels))
        self.att = Parameter(torch.Tensor(1, heads, 2 * out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x, edge_index, size=None):
        """Runs the forward pass on node features :obj:`x` of shape
        :obj:`(batch_size, num_nodes, in_channels)` (or a pair of such
        tensors for bipartite graphs) and :obj:`edge_index` of shape
        :obj:`(2, num_edges)`."""
        if size is None and torch.is_tensor(x):
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index,
                                           num_nodes=x.size(self.node_dim))

        # Linearly project the input features to `heads * out_channels`.
        if torch.is_tensor(x):
            x = torch.matmul(x, self.weight)
        else:
            x = (None if x[0] is None else torch.matmul(x[0], self.weight),
                 None if x[1] is None else torch.matmul(x[1], self.weight))

        return self.propagate(edge_index, size=size, x=x)

    def message(self, edge_index_i, x_i, x_j, size_i):
        """
        :param edge_index_i: target node of each edge, shape (E,)
        :param x_i: target node features, shape (b, E, heads * out_channels)
        :param x_j: source node features, shape (b, E, heads * out_channels)
        :param size_i: number of target nodes
        :return: attention-weighted messages, shape (b, E, heads, out_channels)
        """
        # Split the projected features into separate attention heads.
        x_j = x_j.view(-1, edge_index_i.shape[0], self.heads, self.out_channels)
        x_i = x_i.view(-1, edge_index_i.shape[0], self.heads, self.out_channels)

        # Concatenate x_i and x_j on the last dimension and apply the shared
        # attention vector, yielding one raw score per edge and head.
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, edge_index_i, size_i)

        # Sample attention coefficients stochastically during training.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return x_j * alpha.unsqueeze(-1)

    def update(self, aggr_out):
        b, num_nodes, num_heads, out_channels = aggr_out.size()

        # Apply the ELU activation, following the GAT paper.
        aggr_out = F.elu(aggr_out)
        aggr_out = F.dropout(aggr_out, p=self.dropout, training=self.training)

        if self.concat:
            aggr_out = aggr_out.view(b, num_nodes,
                                     self.heads * self.out_channels)
        else:
            # Average over the heads dimension instead of concatenating.
            aggr_out = aggr_out.mean(dim=2)

        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
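
# Minimal smoke test (a sketch added for illustration, not part of the
# original gist). It assumes a torch_geometric release whose MessagePassing
# accepts a `node_dim` argument and still resolves the legacy `edge_index_i` /
# `size_i` message arguments (roughly the 1.x series this code targets).
if __name__ == '__main__':
    batch_size, num_nodes, in_channels, out_channels, heads = 4, 5, 16, 8, 2
    x = torch.randn(batch_size, num_nodes, in_channels)
    # A small directed cycle shared by every graph in the batch, shape (2, E).
    edge_index = torch.tensor([[0, 1, 2, 3, 4],
                               [1, 2, 3, 4, 0]])

    conv = GATConv(in_channels, out_channels, heads=heads, dropout=0.1)
    out = conv(x, edge_index)

    print(conv)       # GATConv(16, 8, heads=2)
    print(out.shape)  # torch.Size([4, 5, 16]): heads concatenated (2 * 8)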