# -*- coding: utf-8 -*-
#
# Copyright 2018-2019 Data61, CSIRO
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Definition of the Graph Attention Network (GAT) layer, and of the GAT class that is a stack of GAT layers
"""
__all__ = ["GraphAttention", "GAT"]
from keras import activations, constraints, initializers, regularizers
from keras import backend as K
from keras.layers import Input, Layer, Dropout, LeakyReLU, Lambda, Reshape
import numpy as np
import tensorflow as tf
from stellargraph.mapper import FullBatchNodeGenerator
import warnings
warnings.simplefilter("default")
class GraphAttention(Layer):
"""
Graph Attention (GAT) layer. The base implementation is taken from https://github.com/danielegrattarola/keras-gat,
with some modifications added for ease of use.
Based on the original paper: Graph Attention Networks. P. Velickovic et al. ICLR 2018 https://arxiv.org/abs/1710.10903
Args:
F_out (int): dimensionality of output feature vectors
attn_heads (int or list of int): number of attention heads
attn_heads_reduction (str): reduction applied to output features of each attention head, 'concat' or 'average'.
'average' should be applied in the final prediction layer of the model (Eq. 6 of the paper).
in_dropout_rate (float): dropout rate applied to features
attn_dropout_rate (float): dropout rate applied to attention coefficients
activation (str): nonlinear activation applied to layer's output to obtain output features (eq. 4 of the GAT paper)
use_bias (bool): toggles an optional bias
kernel_initializer (str): name of the initializer for kernel parameters (weights)
bias_initializer (str): name of the initializer for bias
attn_kernel_initializer (str): name of the initializer for attention kernel
kernel_regularizer (str): name of regularizer to be applied to layer kernel. Must be a Keras regularizer.
bias_regularizer (str): name of regularizer to be applied to layer bias. Must be a Keras regularizer.
attn_kernel_regularizer (str): name of regularizer to be applied to attention kernel. Must be a Keras regularizer.
activity_regularizer (str): not used in the current implementation
kernel_constraint (str): constraint applied to layer's kernel. Must be a Keras constraint https://keras.io/constraints/
bias_constraint (str): constraint applied to layer's bias. Must be a Keras constraint https://keras.io/constraints/
attn_kernel_constraint (str): constraint applied to attention kernel. Must be a Keras constraint https://keras.io/constraints/
**kwargs: optional keyword arguments
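Example::

    # A minimal usage sketch (illustrative only; the node count 2708 and feature
    # size 1433 are hypothetical). Node features and a dense adjacency matrix are
    # passed as a full batch, with one row per node:
    x_features = Input(shape=(1433,))
    x_adj = Input(shape=(2708,))
    gat_layer = GraphAttention(
        F_out=8, attn_heads=8, attn_heads_reduction="concat", activation="elu"
    )
    x_out = gat_layer([x_features, x_adj])  # output shape (2708, 8 * 8)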
"""
def __init__(
self,
F_out,
attn_heads=1,
attn_heads_reduction="concat", # {'concat', 'average'}
in_dropout_rate=0.0,
attn_dropout_rate=0.0,
activation="relu",
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
attn_kernel_initializer="glorot_uniform",
kernel_regularizer=None,
bias_regularizer=None,
attn_kernel_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
attn_kernel_constraint=None,
**kwargs
):
if attn_heads_reduction not in {"concat", "average"}:
raise ValueError(
"{}: Possible heads reduction methods: concat, average; received {}".format(
type(self).__name__, attn_heads_reduction
)
)
self.F_out = F_out # Number of output features (F' in the paper)
self.attn_heads = attn_heads # Number of attention heads (K in the paper)
self.attn_heads_reduction = attn_heads_reduction # Eq. 5 and 6 in the paper
self.in_dropout_rate = in_dropout_rate # dropout rate for node features
self.attn_dropout_rate = attn_dropout_rate # dropout rate for attention coefs
self.activation = activations.get(activation) # Eq. 4 in the paper
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.attn_kernel_initializer = initializers.get(attn_kernel_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.attn_kernel_regularizer = regularizers.get(attn_kernel_regularizer)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.attn_kernel_constraint = constraints.get(attn_kernel_constraint)
# Populated by build()
self.kernels = [] # Layer kernels for attention heads
self.biases = [] # Layer biases for attention heads
self.attn_kernels = [] # Attention kernels for attention heads
if attn_heads_reduction == "concat":
# Output will have shape (..., K * F')
self.output_dim = self.F_out * self.attn_heads
else:
# Output will have shape (..., F')
self.output_dim = self.F_out
super(GraphAttention, self).__init__(**kwargs)
def get_config(self):
"""
Gets class configuration for Keras serialization
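Example::

    # Sketch: reloading a saved model that contains this layer (the file name
    # "gat_model.h5" is hypothetical); get_config() makes this round-trip possible.
    from keras.models import load_model

    model = load_model(
        "gat_model.h5", custom_objects={"GraphAttention": GraphAttention}
    )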
"""
config = {
"F_out": self.F_out,
"attn_heads": self.attn_heads,
"attn_heads_reduction": self.attn_heads_reduction,
"in_dropout_rate": self.in_dropout_rate,
"attn_dropout_rate": self.attn_dropout_rate,
"activation": activations.serialize(self.activation),
"use_bias": self.use_bias,
"kernel_initializer": initializers.serialize(self.kernel_initializer),
"bias_initializer": initializers.serialize(self.bias_initializer),
"attn_kernel_initializer": initializers.serialize(
self.attn_kernel_initializer
),
"kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
"attn_kernel_regularizer": regularizers.serialize(
self.attn_kernel_regularizer
),
"kernel_constraint": constraints.serialize(self.kernel_constraint),
"bias_constraint": constraints.serialize(self.bias_constraint),
"attn_kernel_constraint": constraints.serialize(
self.attn_kernel_constraint
),
}
base_config = super().get_config()
return {**base_config, **config}
def build(self, input_shape):
"""
Builds the layer
Args:
input_shape (list of list of int): shapes of the layer's input(s)
"""
assert len(input_shape) >= 2
F_in = int(input_shape[0][-1])
# Initialize weights for each attention head
for head in range(self.attn_heads):
# Layer kernel
kernel = self.add_weight(
shape=(F_in, self.F_out),
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
name="kernel_{}".format(head),
)
self.kernels.append(kernel)
# Layer bias
if self.use_bias:
bias = self.add_weight(
shape=(self.F_out,),
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
name="bias_{}".format(head),
)
self.biases.append(bias)
# Attention kernels
attn_kernel_self = self.add_weight(
shape=(self.F_out, 1),
initializer=self.attn_kernel_initializer,
regularizer=self.attn_kernel_regularizer,
constraint=self.attn_kernel_constraint,
name="attn_kernel_self_{}".format(head),
)
attn_kernel_neighs = self.add_weight(
shape=(self.F_out, 1),
initializer=self.attn_kernel_initializer,
regularizer=self.attn_kernel_regularizer,
constraint=self.attn_kernel_constraint,
name="attn_kernel_neigh_{}".format(head),
)
self.attn_kernels.append([attn_kernel_self, attn_kernel_neighs])
self.built = True
def call(self, inputs, **kwargs):
"""
Applies the layer.
Args:
inputs (list): list of inputs with 2 items: node features (matrix of size N x F),
and graph adjacency matrix (size N x N), where N is the number of nodes in the graph,
F is the dimensionality of node features
"""
X = inputs[0] # Node features (N x F)
A = inputs[1] # Adjacency matrix (N x N)
# Convert A to dense tensor - needed for the mask to work
# TODO: replace this dense implementation of GraphAttention layer with a sparse implementation
if K.is_sparse(A):
A = tf.sparse_tensor_to_dense(A, validate_indices=False)
# For the GAT model to match that in the paper, we need to ensure that the graph has self-loops,
# since the neighbourhood of node i in eq. (4) includes node i itself.
# Adding self-loops to A via setting the diagonal elements of A to 1.0:
if kwargs.get("add_self_loops", False):
# get the number of nodes from inputs[1] directly
N = K.int_shape(inputs[1])[-1]
if N is not None:
# create self-loops
A = tf.linalg.set_diag(A, K.cast(np.ones((N,)), dtype="float"))
else:
raise ValueError(
"{}: need to know number of nodes to add self-loops; obtained None instead".format(
type(self).__name__
)
)
outputs = []
for head in range(self.attn_heads):
kernel = self.kernels[head] # W in the paper (F x F')
attention_kernel = self.attn_kernels[
head
] # Attention kernel a in the paper (2F' x 1)
# Compute inputs to attention network
features = K.dot(X, kernel) # (N x F')
# Compute feature combinations
# Note: [[a_1], [a_2]]^T [[Wh_i], [Wh_j]] = [a_1]^T [Wh_i] + [a_2]^T [Wh_j]
attn_for_self = K.dot(
features, attention_kernel[0]
) # (N x 1), [a_1]^T [Wh_i]
attn_for_neighs = K.dot(
features, attention_kernel[1]
) # (N x 1), [a_2]^T [Wh_j]
# Attention head a(Wh_i, Wh_j) = a^T [[Wh_i], [Wh_j]]
dense = attn_for_self + K.transpose(
attn_for_neighs
) # (N x N) via broadcasting
# Add nonlinearity
dense = LeakyReLU(alpha=0.2)(dense)
# Mask values before activation (Vaswani et al., 2017)
# YT: this only works for 'binary' A, not for 'weighted' A!
# YT: if A does not have self-loops, the node itself will be masked, so A should have self-loops
# YT: this is ensured by setting the diagonal elements of A tensor to 1 above
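# For example, wherever A[i, j] == 0 (no edge from i to j) the logit below is
# shifted by -1e10, so the softmax assigns (numerically) zero attention weight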
mask = -10e9 * (1.0 - A)
dense += mask
# Apply softmax to get attention coefficients
dense = K.softmax(dense) # (N x N), Eq. 3 of the paper
# Apply dropout to features and attention coefficients
dropout_feat = Dropout(self.in_dropout_rate)(features) # (N x F')
dropout_attn = Dropout(self.attn_dropout_rate)(dense) # (N x N)
# Linear combination with neighbors' features [YT: see Eq. 4]
node_features = K.dot(dropout_attn, dropout_feat) # (N x F')
if self.use_bias:
node_features = K.bias_add(node_features, self.biases[head])
# Add output of attention head to final output
outputs.append(node_features)
# Aggregate the heads' output according to the reduction method
if self.attn_heads_reduction == "concat":
output = K.concatenate(outputs) # (N x KF')
else:
output = K.mean(K.stack(outputs), axis=0)  # (N x F')
output = self.activation(output)
return output
def compute_output_shape(self, input_shape):
output_shape = input_shape[0][0], self.output_dim
return output_shape
class GAT:
"""
A stack of Graph Attention (GAT) layers with aggregation of multiple attention heads, Eqs. 5-6 of the GAT paper https://arxiv.org/abs/1710.10903
Args:
layer_sizes (list of int): list of output sizes of GAT layers in the stack. The length of this list defines
the number of GraphAttention layers in the stack.
attn_heads (int or list of int): number of attention heads in GraphAttention layers. The options are:
- a single integer: the passed value of `attn_heads` will be applied to all GraphAttention layers in the stack, except the last layer (for which the number of attn_heads will be set to 1).
- a list of integers: elements of the list define the number of attention heads in the corresponding layers in the stack.
attn_heads_reduction (list of str or None): reductions applied to output features of each attention head,
for all layers in the stack. Valid entries in the list are {'concat', 'average'}.
If None is passed, the default reductions are applied: 'concat' reduction to all layers in the stack
except the final layer, 'average' reduction to the last layer (Eqs. 5-6 of the GAT paper).
activations (list of str): list of activations applied to each layer's output
bias (bool): toggles an optional bias in GAT layers
in_dropout (float): dropout rate applied to input features of each GAT layer
attn_dropout (float): dropout rate applied to attention maps
normalize (str or None): normalization applied to the final output features of the GAT layers stack
generator (FullBatchNodeGenerator): an instance of FullBatchNodeGenerator class constructed on the graph of interest
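Example::

    # A minimal sketch (illustrative only; `G` is assumed to be an existing
    # StellarGraph object with numeric node features, and `n_classes` the
    # number of target classes):
    generator = FullBatchNodeGenerator(G)
    gat = GAT(
        layer_sizes=[8, n_classes],
        activations=["elu", "softmax"],
        attn_heads=8,
        in_dropout=0.5,
        attn_dropout=0.5,
        generator=generator,
    )
    x_inp, predictions = gat.node_model()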
"""
def __init__(
self,
layer_sizes,
activations,
attn_heads=1,
attn_heads_reduction=None,
bias=True,
in_dropout=0.0,
attn_dropout=0.0,
normalize="l2",
generator=None,
):
self._gat_layer = GraphAttention
self.bias = bias
self.in_dropout = in_dropout
self.attn_dropout = attn_dropout
self.generator = generator
# Check layer_sizes (must be list of int):
# check type:
if not isinstance(layer_sizes, list):
raise TypeError(
"{}: layer_sizes should be a list of integers; received type {} instead.".format(
type(self).__name__, type(layer_sizes).__name__
)
)
# check that values are valid:
elif not all([isinstance(s, int) and s > 0 for s in layer_sizes]):
raise ValueError(
"{}: all elements in layer_sizes should be positive integers!".format(
type(self).__name__
)
)
# Check attn_heads (must be int or list of int):
if isinstance(attn_heads, list):
# check the length
if not len(attn_heads) == len(layer_sizes):
raise ValueError(
"{}: length of attn_heads list ({}) should match the number of GAT layers ({})".format(
type(self).__name__, len(attn_heads), len(layer_sizes)
)
)
# check that values in the list are valid
if not all([isinstance(a, int) and a > 0 for a in attn_heads]):
raise ValueError(
"{}: all elements in attn_heads should be positive integers!".format(
type(self).__name__
)
)
self.attn_heads = attn_heads # (list of int as passed by the user)
elif isinstance(attn_heads, int):
self.attn_heads = list()
for l, _ in enumerate(layer_sizes):
# number of attention heads for layer l: attn_heads (int) for all but the last layer (for which it's set to 1)
self.attn_heads.append(attn_heads if l < len(layer_sizes) - 1 else 1)
else:
raise TypeError(
"{}: attn_heads should be an integer or a list of integers!".format(
type(self).__name__
)
)
# Check attn_heads_reduction (list of str, or None):
if attn_heads_reduction is None:
# set default head reductions, see eqs 5-6 of the GAT paper
self.attn_heads_reduction = ["concat"] * (len(layer_sizes) - 1) + [
"average"
]
else:
# user-specified list of head reductions (valid entries are 'concat' and 'average')
# check type (must be a list of str):
if not isinstance(attn_heads_reduction, list):
raise TypeError(
"{}: attn_heads_reduction should be a string; received type {} instead.".format(
type(self).__name__, type(attn_heads_reduction).__name__
)
)
# check length of attn_heads_reduction list:
if not len(attn_heads_reduction) == len(layer_sizes):
raise ValueError(
"{}: length of attn_heads_reduction list ({}) should match the number of GAT layers ({})".format(
type(self).__name__, len(attn_heads_reduction), len(layer_sizes)
)
)
# check that list elements are valid:
if all(
[ahr.lower() in {"concat", "average"} for ahr in attn_heads_reduction]
):
self.attn_heads_reduction = attn_heads_reduction
else:
raise ValueError(
"{}: elements of attn_heads_reduction list should be either 'concat' or 'average'!".format(
type(self).__name__
)
)
# Check activations (list of str):
# check type:
if not isinstance(activations, list):
raise TypeError(
"{}: activations should be a list of strings; received {} instead".format(
type(self).__name__, type(activations)
)
)
# check length:
if not len(activations) == len(layer_sizes):
raise ValueError(
"{}: length of activations list ({}) should match the number of GAT layers ({})".format(
type(self).__name__, len(activations), len(layer_sizes)
)
)
self.activations = activations
# check generator:
if generator is not None:
if not isinstance(generator, FullBatchNodeGenerator):
raise ValueError(
"{}: generator must be of type FullBatchNodeGenerator or None; received object of type {} instead".format(
type(self).__name__, type(generator).__name__
)
)
# Set the normalization layer used in the model
if normalize == "l2":
self._normalization = Lambda(lambda x: K.l2_normalize(x, axis=1))
elif normalize is None or str(normalize).lower() in {"none", "linear"}:
self._normalization = Lambda(lambda x: x)
else:
raise ValueError(
"Normalization should be either 'l2' or None (also allowed as 'none'); received '{}'".format(
normalize
)
)
# Initialize a stack of GAT layers
self._layers = []
for l, F_out in enumerate(layer_sizes):
# Dropout on input node features before each GAT layer
self._layers.append(Dropout(self.in_dropout))
# GraphAttention layer
self._layers.append(
self._gat_layer(
F_out=F_out,
attn_heads=self.attn_heads[l],
attn_heads_reduction=self.attn_heads_reduction[l],
in_dropout_rate=self.in_dropout,
attn_dropout_rate=self.attn_dropout,
activation=self.activations[l],
use_bias=self.bias,
)
)
def __call__(self, x_inp, **kwargs):
"""
Apply a stack of GAT layers to the input x_inp
Args:
x_inp (list): list of input tensors to the 1st GAT layer in the stack (node features and graph adjacency matrix)
Returns: Output tensor of the GAT layers stack
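Example::

    # Sketch: apply the stack directly to existing Keras input tensors
    # (X_in holds node features, A_in the adjacency matrix):
    x_out = gat([X_in, A_in], add_self_loops=True)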
"""
assert isinstance(x_inp, list), "input must be a list, got {} instead".format(
type(x_inp)
)
x = x_inp[0]
A = x_inp[1]
for layer in self._layers:
if isinstance(layer, self._gat_layer): # layer is a GAT layer
x = layer([x, A], add_self_loops=kwargs.get("add_self_loops"))
else: # layer is a Dropout layer
x = layer(x)
return self._normalization(x)
def node_model(self, num_nodes=None, feature_size=None, add_self_loops=True):
"""
Builds a GAT model for node prediction
Args:
num_nodes (int or None): (optional) number of nodes in the graph (in the full batch). If not provided, this will be taken from self.generator.
feature_size (int or None): (optional) dimensionality of node attributes. If not provided, this will be taken from self.generator.
add_self_loops (bool): (default is True) toggles adding self-loops to the graph's adjacency matrix in the GraphAttention layers of the GAT model.
Returns:
tuple: `(x_inp, x_out)`, where `x_inp` is a list of two Keras input tensors for the specified GAT model
(containing node features and graph adjacency matrix), and `x_out` is a Keras tensor for the GAT model output.
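Example::

    # Sketch: wrap the returned tensors in a Keras Model for training
    # (`gat` is assumed to have been constructed with a FullBatchNodeGenerator):
    from keras.models import Model

    x_inp, x_out = gat.node_model()
    model = Model(inputs=x_inp, outputs=x_out)
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["acc"])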
"""
# Create input tensor:
if self.generator is not None:
N = self.generator.Aadj.shape[0]
assert self.generator.features.shape[0] == N
F = self.generator.features.shape[1]
is_adj_sparse = self.generator.sparse
elif num_nodes is not None and feature_size is not None:
N = num_nodes
F = feature_size
is_adj_sparse = True
else:
raise RuntimeError(
"node_model: if generator is not provided to object constructor, num_nodes and feature_size must be specified."
)
X_in = Input(shape=(F,))
A_in = Input(shape=(N,), sparse=is_adj_sparse)
x_inp = [X_in, A_in]
# Output from GAT model, N x F', where F' is the output size of the last GAT layer in the stack
x_out = self(x_inp, add_self_loops=add_self_loops)
return x_inp, x_out
def link_model(self):
"""
Builds a GAT model for link (node pair) prediction (implementation pending)
"""
raise NotImplementedError
def default_model(self, flatten_output=False):
warnings.warn(
"The .default_model() method will be deprecated soon. "
"Please use .node_model() or .link_model() methods instead.",
PendingDeprecationWarning,
)
return self.node_model()