# -*- coding: utf-8 -*-
#
# Copyright 2020 Data61, CSIRO
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import GCN, GAT, APPNP, PPNP
from .misc import deprecated_model_function
from tensorflow.keras.layers import Input, Lambda, Layer, GlobalAveragePooling1D
import tensorflow as tf
from tensorflow.keras import backend as K
import warnings
__all__ = ["DeepGraphInfomax", "DGIDiscriminator"]
class DGIDiscriminator(Layer):
    """
    Keras layer implementing the discriminator function of Deep Graph Infomax
    (https://arxiv.org/pdf/1809.10341.pdf).

    Every node feature vector is scored against a summary vector through a learned
    bilinear form: ``score = features . (kernel . summary)``.
    """

    def __init__(self, *args, **kwargs):
        # No layer-specific options: everything is forwarded to the Keras ``Layer``.
        super().__init__(*args, **kwargs)

    def build(self, input_shapes):
        """
        Creates the bilinear ``kernel`` weight.

        Args:
            input_shapes: the shapes of the two inputs (node features first, summary second);
                only the last dimension of each is used, giving a kernel of shape (F1, F2).
        """
        feature_dim = input_shapes[0][-1]
        summary_dim = input_shapes[1][-1]

        self.kernel = self.add_weight(
            name="kernel",
            shape=(feature_dim, summary_dim),
            initializer="glorot_uniform",
            regularizer=None,
            constraint=None,
        )
        self.built = True

    def call(self, inputs):
        """
        Applies the layer to the inputs.

        Args:
            inputs: a list or tuple of two tensors, the node features and a summary feature
                vector, with shapes [(batch, N, F1), (batch, F2)].

        Returns:
            a Tensor holding one score per node, i.e. shape (batch, N) for the input shapes
            above (this is (1, N) when the batch dimension is 1).
        """
        node_features, graph_summary = inputs

        # Project the summary through the kernel first, (batch, F2) -> (batch, F1), then
        # take its inner product with every node's feature vector.
        projected_summary = tf.linalg.matvec(self.kernel, graph_summary)
        return tf.linalg.matvec(node_features, projected_summary)
class DeepGraphInfomax:
    """
    A wrapper around a stellargraph model for Deep Graph Infomax unsupervised training
    (https://arxiv.org/pdf/1809.10341.pdf).

    Args:
        base_model: the base stellargraph model to wrap; must be an instance of GCN, GAT,
            APPNP or PPNP.

    Raises:
        TypeError: if ``base_model`` is not an instance of one of the supported models.
    """

    def __init__(self, base_model):
        supported_models = (GCN, GAT, APPNP, PPNP)
        if not isinstance(base_model, supported_models):
            raise TypeError(
                f"base_model: expected GCN, GAT, APPNP or PPNP found {type(base_model).__name__}"
            )

        # A link model is tolerated, but its tensors are always requested with multiplicity=1,
        # so let the caller know.
        if base_model.multiplicity != 1:
            warnings.warn(
                f"base_model: expected a node model (multiplicity = 1), found a link model (multiplicity = {base_model.multiplicity}). Base model tensors will be constructed as for a node model.",
                stacklevel=2,
            )

        self.base_model = base_model
        self._node_feats = None

        # specific to full batch models: positions, within the base model's input list,
        # of the inputs that get a corrupted counterpart
        self._corruptible_inputs_idxs = [0]

        self._discriminator = DGIDiscriminator()

    def in_out_tensors(self):
        """
        Creates the keras inputs and outputs for a Deep Graph Infomax model for unsupervised training.

        Note that the tf.nn.sigmoid_cross_entropy_with_logits loss must be used with this model.

        Example::

            dg_infomax = DeepGraphInfomax(...)
            x_in, x_out = dg_infomax.in_out_tensors()
            model = Model(inputs=x_in, outputs=x_out)
            model.compile(loss=tf.nn.sigmoid_cross_entropy_with_logits, ...)

        Returns:
            input and output layers for use with a keras model: the inputs are the corrupted
            inputs followed by the base model's own inputs, and the output stacks the scores of
            the true and the corrupted node features along axis 2
        """
        clean_inputs, clean_feats = self.base_model.in_out_tensors(multiplicity=1)

        # One fresh Input per corruptible position, shaped like the input it stands in for.
        corrupted_inputs = []
        for idx in self._corruptible_inputs_idxs:
            corrupted_inputs.append(Input(batch_shape=clean_inputs[idx].shape))

        # shallow copy of the normal inputs, with the corruptible ones swapped for the new inputs
        mixed_inputs = clean_inputs.copy()
        for idx, replacement in zip(self._corruptible_inputs_idxs, corrupted_inputs):
            mixed_inputs[idx] = replacement

        corrupted_feats = self.base_model(mixed_inputs)

        # Summary vector: sigmoid of the average-pooled (uncorrupted) node features.
        pooled_feats = GlobalAveragePooling1D()(clean_feats)
        summary = tf.keras.activations.sigmoid(pooled_feats)

        # Both the true and the corrupted features are scored against the same summary.
        true_scores = self._discriminator([clean_feats, summary])
        corrupted_scores = self._discriminator([corrupted_feats, summary])

        stacked_scores = tf.stack([true_scores, corrupted_scores], axis=2)

        return corrupted_inputs + clean_inputs, stacked_scores

    def embedding_model(self):
        """
        Creates the inputs and outputs for an embedding model.

        Returns:
            input and output layers for use with a keras model; the output is the base model's
            output with axis 0 squeezed out
        """
        # these tensors should link into the weights trained through the tensors returned by
        # `in_out_tensors` (NOTE(review): relies on the base model reusing its layers between
        # calls; confirm against the base model implementation)
        emb_inputs, emb_feats = self.base_model.in_out_tensors(multiplicity=1)

        drop_leading_axis = Lambda(lambda x: K.squeeze(x, axis=0), name="squeeze")
        return emb_inputs, drop_leading_axis(emb_feats)

    # Deprecated alias, kept so that callers still using `build` keep working.
    build = deprecated_model_function(in_out_tensors, "build")