# -*- coding: utf-8 -*-
#
# Copyright 2020 Data61, CSIRO
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import GCN, GAT, APPNP, PPNP, GraphSAGE, DirectedGraphSAGE
from .misc import deprecated_model_function
from ..mapper import CorruptedGenerator
from tensorflow.keras.layers import Input, Lambda, Layer, GlobalAveragePooling1D
import tensorflow as tf
from tensorflow.keras import backend as K
import warnings
import numpy as np
# Public API of this module; only these names are exported by ``import *``.
__all__ = ["DeepGraphInfomax", "DGIDiscriminator"]
class DGIDiscriminator(Layer):
    """
    This Layer computes the Discriminator function for Deep Graph Infomax (https://arxiv.org/pdf/1809.10341.pdf).

    The score for a (node features, summary) pair is the bilinear form
    ``features @ kernel @ summary`` where ``kernel`` is a trainable weight matrix.

    .. seealso:: :class:`.DeepGraphInfomax` uses this layer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def build(self, input_shapes):
        # Bilinear weight: (node feature dim) x (summary feature dim).
        first_size = input_shapes[0][-1]
        second_size = input_shapes[1][-1]

        self.kernel = self.add_weight(
            shape=(first_size, second_size),
            initializer="glorot_uniform",
            name="kernel",
            regularizer=None,
            constraint=None,
        )
        self.built = True

    def call(self, inputs):
        """
        Applies the layer to the inputs.

        Args:
            inputs: a list or tuple of tensors with shapes ``[(1, N, F), (1, F)]`` for full batch methods and shapes
                ``[(B, F), (F,)]`` for sampled node methods, containing the node features and a summary feature vector.
                Where ``N`` is the number of nodes in the graph, ``F`` is the feature dimension, and ``B`` is the batch size.

        Returns:
            a Tensor with shape ``(1, N)`` for full batch methods and shape ``(B,)`` for sampled node methods.
        """
        features, summary = inputs

        # features @ (kernel @ summary): matvec broadcasts over the leading
        # batch/node dimensions, producing one logit per node.
        score = tf.linalg.matvec(features, tf.linalg.matvec(self.kernel, summary))

        return score
class DGIReadout(Layer):
    """
    This Layer computes the Readout function for Deep Graph Infomax (https://arxiv.org/pdf/1809.10341.pdf).

    The graph summary is the element-wise sigmoid of the mean of the node
    features over the node axis.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def build(self, input_shapes):
        # The readout has no trainable weights: it is a fixed mean-pool
        # followed by a sigmoid.
        self.built = True

    def call(self, node_feats):
        """
        Applies the layer to the inputs.

        Args:
            node_feats: a tensor containing the batch node features from the base model. This has shape `(1, N, F)`
                for full batch methods and shape `(B, F)` for sampled node methods. Where `N` is the number of nodes
                in the graph, `F` is the feature dimension, and `B` is the batch size.

        Returns:
            a Tensor with shape `(1, F)` for full batch methods and shape `(F,)` for sampled node methods.
        """
        # axis=-2 is the node axis in both the full-batch and sampled layouts.
        pooled = tf.reduce_mean(node_feats, axis=-2)
        return tf.math.sigmoid(pooled)
class DeepGraphInfomax:
    """
    A class to wrap stellargraph models for Deep Graph Infomax unsupervised training
    (https://arxiv.org/pdf/1809.10341.pdf).

    .. seealso::

       Examples using Deep Graph Infomax:

       - `unsupervised representation learning <https://stellargraph.readthedocs.io/en/stable/demos/embeddings/deep-graph-infomax-embeddings.html>`__
       - `semi-supervised node classification <https://stellargraph.readthedocs.io/en/stable/demos/node-classification/gcn-deep-graph-infomax-fine-tuning-node-classification.html>`__

       Appropriate data generator: :class:`.CorruptedGenerator`.

    Args:
        base_model: the base stellargraph model class
        corrupted_generator: the :class:`.CorruptedGenerator` whose
            ``corrupt_index_groups`` identify which model inputs get corrupted;
            if ``None`` (deprecated), the corruptible inputs are inferred from
            the type of ``base_model``.
    """

    def __init__(self, base_model, corrupted_generator=None):
        if corrupted_generator is None:
            warnings.warn(
                "The 'corrupted_generator' parameter should be set to an instance of `CorruptedGenerator`, because the support for specific algorithms is being replaced by a more general form",
                DeprecationWarning,
                stacklevel=2,
            )
            # Legacy path: hard-code which of the base model's inputs hold node
            # features, per supported algorithm.
            if isinstance(base_model, (GCN, GAT, APPNP, PPNP)):
                self._corruptible_inputs_idxs = [0]
            elif isinstance(base_model, DirectedGraphSAGE):
                self._corruptible_inputs_idxs = np.arange(base_model.max_slots)
            elif isinstance(base_model, GraphSAGE):
                self._corruptible_inputs_idxs = np.arange(base_model.max_hops + 1)
            else:
                raise TypeError(
                    "base_model: expected GCN, GAT, APPNP, PPNP, GraphSAGE, "
                    f"or DirectedGraphSAGE, found {type(base_model).__name__}"
                )
        elif not isinstance(corrupted_generator, CorruptedGenerator):
            raise TypeError(
                f"corrupted_generator: expected a CorruptedGenerator, found {type(corrupted_generator).__name__}"
            )
        else:
            # Flatten the generator's groups of corruptible input indices into
            # a single list.
            self._corruptible_inputs_idxs = [
                idx
                for group in corrupted_generator.corrupt_index_groups
                for idx in group
            ]

        self.base_model = base_model
        self._node_feats = None
        # Shared discriminator: the same bilinear weights score both the clean
        # and the corrupted node features.
        self._discriminator = DGIDiscriminator()

    def in_out_tensors(self):
        """
        A function to create the the Keras inputs and outputs for a Deep Graph Infomax model for unsupervised training.

        Note that the :func:`tensorflow.nn.sigmoid_cross_entropy_with_logits` loss must be used with this model.

        Example::

            dg_infomax = DeepGraphInfoMax(...)
            x_in, x_out = dg_infomax.in_out_tensors()
            model = Model(inputs=x_in, outputs=x_out)
            model.compile(loss=tf.nn.sigmoid_cross_entropy_with_logits, ...)

        Returns:
            input and output layers for use with a Keras model
        """
        x_inp, node_feats = self.base_model.in_out_tensors()

        # New placeholders that will receive the corrupted versions of the
        # corruptible inputs; their shapes mirror the originals.
        x_corr = [
            Input(batch_shape=x_inp[i].shape) for i in self._corruptible_inputs_idxs
        ]

        # shallow copy normal inputs and replace corruptible inputs with new inputs
        x_in_corr = x_inp.copy()
        for i, x in zip(self._corruptible_inputs_idxs, x_corr):
            x_in_corr[i] = x

        node_feats_corr = self.base_model(x_in_corr)

        # Summary is computed from the *clean* features only, then scored
        # against both clean and corrupted node features.
        summary = DGIReadout()(node_feats)

        scores = self._discriminator([node_feats, summary])
        scores_corrupted = self._discriminator([node_feats_corr, summary])

        x_out = tf.stack([scores, scores_corrupted], axis=-1)

        return x_corr + x_inp, x_out

    def embedding_model(self):
        """
        Deprecated: use ``base_model.in_out_tensors`` instead. Deep Graph Infomax just trains the base model,
        and the model behaves as usual after training.
        """
        warnings.warn(
            "The 'embedding_model' method is deprecated, use 'base_model.in_out_tensors' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # these tensors should link into the weights that get trained by `build`
        x_emb_in, x_emb_out = self.base_model.in_out_tensors()

        # squeeze out batch dim of full batch models
        if len(x_emb_out.shape) == 3:
            squeeze_layer = Lambda(lambda x: K.squeeze(x, axis=0), name="squeeze")
            x_emb_out = squeeze_layer(x_emb_out)

        return x_emb_in, x_emb_out

    build = deprecated_model_function(in_out_tensors, "build")