# -*- coding: utf-8 -*-
#
# Copyright 2018-2020 Data61, CSIRO
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mappers to provide input data for the graph models in layers.
"""
__all__ = ["ClusterNodeGenerator", "ClusterNodeSequence"]
import random
import copy
import numpy as np
import networkx as nx
from tensorflow.keras.utils import Sequence
from scipy import sparse
from ..core.graph import StellarGraph
from ..core.utils import is_real_iterable, normalize_adj
from ..connector.neo4j.graph import Neo4jStellarGraph
from .base import Generator
[docs]class ClusterNodeGenerator(Generator):
"""
A data generator for use with GCN, GAT and APPNP models on homogeneous graphs, see [1].
The supplied graph G should be a StellarGraph object with node features.
Use the :meth:`flow` method supplying the nodes and (optionally) targets
to get an object that can be used as a Keras data generator.
This generator will supply the features array and the adjacency matrix to a
mini-batch Keras graph ML model.
[1] `W. Chiang, X. Liu, S. Si, Y. Li, S. Bengio, C. Hsieh, 2019 <https://arxiv.org/abs/1905.07953>`_.
.. seealso::
Models using this generator: :class:`.GCN`, :class:`.GAT`, :class:`.APPNP`.
Examples using this generator:
- `Cluster-GCN node classification <https://stellargraph.readthedocs.io/en/stable/demos/node-classification/cluster-gcn-node-classification.html>`__
- `Cluster-GCN node classification with Neo4j <https://stellargraph.readthedocs.io/en/stable/demos/connector/neo4j/cluster-gcn-on-cora-neo4j-example.html>`__
- `unsupervised representation learning with Deep Graph Infomax <https://stellargraph.readthedocs.io/en/stable/demos/embeddings/deep-graph-infomax-embeddings.html>`__
Args:
G (StellarGraph): a machine-learning StellarGraph-type graph
clusters (int or list, optional): If int, it indicates the number of clusters (default is 1, corresponding to the entire graph).
If `clusters` is greater than 1, then nodes are randomly assigned to a cluster.
If list, then it should be a list of lists of node IDs, such that each list corresponds to a cluster of nodes
in `G`. The clusters should be non-overlapping.
q (int, optional): The number of clusters to combine for each mini-batch (default is 1).
The total number of clusters must be divisible by `q`.
lam (float, optional): The mixture coefficient for adjacency matrix normalisation (default is 0.1).
Valid values are in the interval [0, 1].
weighted (bool, optional): if True, use the edge weights from ``G``; if False, treat the
graph as unweighted.
name (str, optional): Name for the node generator.
"""
def __init__(self, G, clusters=1, q=1, lam=0.1, weighted=False, name=None):
if not isinstance(G, (StellarGraph, Neo4jStellarGraph)):
raise TypeError("Graph must be a StellarGraph or StellarDiGraph object.")
self.graph = G
self.name = name
self.q = q # The number of clusters to sample per mini-batch
self.lam = lam
self.clusters = clusters
self.method = "cluster_gcn"
self.multiplicity = 1
self.use_sparse = False
self.weighted = weighted
if isinstance(clusters, list):
self.k = len(clusters)
elif isinstance(clusters, int):
if clusters <= 0:
raise ValueError(
"{}: clusters must be greater than 0.".format(type(self).__name__)
)
self.k = clusters
else:
raise TypeError(
"{}: clusters must be either int or list type.".format(
type(self).__name__
)
)
# Some error checking on the given parameter values
if not isinstance(lam, float):
raise TypeError("{}: lam must be a float type.".format(type(self).__name__))
if lam < 0 or lam > 1:
raise ValueError(
"{}: lam must be in the range [0, 1].".format(type(self).__name__)
)
if not isinstance(q, int):
raise TypeError("{}: q must be integer type.".format(type(self).__name__))
if q <= 0:
raise ValueError(
"{}: q must be greater than 0.".format(type(self).__name__)
)
if self.k % q != 0:
raise ValueError(
"{}: the number of clusters must be exactly divisible by q.".format(
type(self).__name__
)
)
self.node_list = list(G.nodes())
# if graph is a StellarGraph check that the graph has features
G.check_graph_for_ml(expensive_check=False)
# Check that there is only a single node type
_ = G.unique_node_type(
"G: expected a graph with a single node type, found a graph with node types: %(found)s"
)
if isinstance(clusters, int):
# We are not given graph clusters.
# We are going to split the graph into self.k random clusters
all_nodes = list(G.nodes())
random.shuffle(all_nodes)
cluster_size = len(all_nodes) // self.k
self.clusters = [
all_nodes[i : i + cluster_size]
for i in range(0, len(all_nodes), cluster_size)
]
if len(self.clusters) > self.k:
# for the case that the number of nodes is not exactly divisible by k, we combine
# the last cluster with the second last one
self.clusters[-2].extend(self.clusters[-1])
del self.clusters[-1]
print(f"Number of clusters {self.k}")
for i, c in enumerate(self.clusters):
print(f"{i} cluster has size {len(c)}")
# Store the features of one node to allow graph ML models to peak at the feature dimension
# FIXME 1621: store feature_dimension here instead of features. This must also update ClusterGCN, and all
# fullbactch methods and generators
self.features = G.node_features(self.node_list[:1])
[docs] def num_batch_dims(self):
return 2
[docs] def flow(self, node_ids, targets=None, name=None):
"""
Creates a generator/sequence object for training, evaluation, or prediction
with the supplied node ids and numeric targets.
Args:
node_ids (iterable): an iterable of node ids for the nodes of interest
(e.g., training, validation, or test set nodes)
targets (2d array, optional): a 2D array of numeric node targets with shape ``(len(node_ids),
target_size)``
name (str, optional): An optional name for the returned generator object.
Returns:
A :class:`ClusterNodeSequence` object to use with :class:`.GCN`, :class:`.GAT` or :class:`.APPNP` in Keras
methods :meth:`fit`, :meth:`evaluate`, and :meth:`predict`.
"""
if targets is not None:
# Check targets is an iterable
if not is_real_iterable(targets):
raise TypeError(
"{}: Targets must be an iterable or None".format(
type(self).__name__
)
)
# Check targets correct shape
if len(targets) != len(node_ids):
raise ValueError(
"{}: Targets must be the same length as node_ids".format(
type(self).__name__
)
)
return ClusterNodeSequence(
self.graph,
self.clusters,
targets=targets,
node_ids=node_ids,
q=self.q,
lam=self.lam,
weighted=self.weighted,
name=name,
)
class ClusterNodeSequence(Sequence):
"""
A Keras-compatible data generator for node inference using ClusterGCN model.
Use this class with the Keras methods :meth:`keras.Model.fit`,
:meth:`keras.Model.evaluate`, and
:meth:`keras.Model.predict`.
This class should be created using the :meth:`flow` method of
:class:`.ClusterNodeGenerator`.
Args:
graph (StellarGraph): The graph
clusters (list): A list of lists such that each sub-list indicates the nodes in a cluster.
The length of this list, len(clusters) indicates the number of batches in one epoch.
targets (np.ndarray, optional): An optional array of node targets of size (N x C),
where C is the target size (e.g., number of classes for one-hot class targets)
node_ids (iterable, optional): The node IDs for the target nodes. Required if targets is not None.
normalize_adj (bool, optional): Specifies whether the adjacency matrix for each mini-batch should
be normalized or not. The default is True.
q (int, optional): The number of subgraphs to combine for each batch. The default value is
1 such that the generator treats each subgraph as a batch.
lam (float, optional): The mixture coefficient for adjacency matrix normalisation (the
'diagonal enhancement' method). Valid values are in the interval [0, 1] and the default value is 0.1.
name (str, optional): An optional name for this generator object.
"""
def __init__(
self,
graph,
clusters,
targets=None,
node_ids=None,
normalize_adj=True,
q=1,
lam=0.1,
weighted=False,
name=None,
):
self.name = name
self.clusters = list()
self.clusters_original = copy.deepcopy(clusters)
self.graph = graph
self.node_list = list(graph.nodes())
self.normalize_adj = normalize_adj
self.q = q
self.lam = lam
self.weighted = weighted
self.node_order = list()
self._node_order_in_progress = list()
self.__node_buffer = dict()
self.target_ids = list()
if len(clusters) % self.q != 0:
raise ValueError(
"The number of clusters should be exactly divisible by q. However, {} number of clusters is not exactly divisible by {}.".format(
len(clusters), q
)
)
if node_ids is not None:
self.target_ids = list(node_ids)
if targets is not None:
if node_ids is None:
raise ValueError(
"Since targets is not None, node_ids must be given and cannot be None."
)
if len(node_ids) != len(targets):
raise ValueError(
"When passed together targets and indices should be the same length."
)
self.targets = np.asanyarray(targets)
self.target_node_lookup = dict(
zip(self.target_ids, range(len(self.target_ids)))
)
else:
self.targets = None
self.on_epoch_end()
def __len__(self):
num_batches = len(self.clusters_original) // self.q
return num_batches
def _diagonal_enhanced_normalization(self, adj_cluster):
# Cluster-GCN normalization is:
# A~ + λdiag(A~) where A~ = N(A + I) with normalization factor N = (D + I)^(-1)
#
# Expands to:
# NA + NI + λN(diag(A) + I) =
# NA + N(I + λ(diag(A) + I)) =
# NA + λN(diag(A) + (1 + 1/λ)I))
#
# (This could potentially become a layer, to benefit from a GPU.)
degrees = np.asarray(adj_cluster.sum(axis=1)).ravel()
normalization = 1 / (degrees + 1)
# NA: multiply rows manually
norm_adj = adj_cluster.multiply(normalization[:, None]).toarray()
# λN(diag(A) + (1 + 1/λ)I): work with the diagonals directly
diag = np.diag(norm_adj)
diag_addition = (
normalization * self.lam * (adj_cluster.diagonal() + (1 + 1 / self.lam))
)
np.fill_diagonal(norm_adj, diag + diag_addition)
return norm_adj
def __getitem__(self, index):
# The next batch should be the adjacency matrix for the cluster and the corresponding feature vectors
# and targets if available.
cluster = self.clusters[index]
adj_cluster = self.graph.to_adjacency_matrix(cluster, weighted=self.weighted)
if self.normalize_adj:
adj_cluster = self._diagonal_enhanced_normalization(adj_cluster)
else:
adj_cluster = adj_cluster.toarray()
g_node_list = list(cluster)
# Determine the target nodes that exist in this cluster
target_nodes_in_cluster = np.asanyarray(
list(set(g_node_list).intersection(self.target_ids))
)
self.__node_buffer[index] = target_nodes_in_cluster
# Dictionary to store node indices for quicker node index lookups
node_lookup = dict(zip(g_node_list, range(len(g_node_list))))
# The list of indices of the target nodes in self.node_list
target_node_indices = np.array(
[node_lookup[n] for n in target_nodes_in_cluster]
)
if index == (len(self.clusters_original) // self.q) - 1:
# last batch
self.__node_buffer_dict_to_list()
cluster_targets = None
#
if self.targets is not None:
# Dictionary to store node indices for quicker node index lookups
# The list of indices of the target nodes in self.node_list
cluster_target_indices = np.array(
[self.target_node_lookup[n] for n in target_nodes_in_cluster],
dtype=np.int64,
)
cluster_targets = self.targets[cluster_target_indices]
cluster_targets = cluster_targets.reshape((1,) + cluster_targets.shape)
features = self.graph.node_features(g_node_list)
features = np.reshape(features, (1,) + features.shape)
adj_cluster = adj_cluster.reshape((1,) + adj_cluster.shape)
target_node_indices = target_node_indices[np.newaxis, :]
return [features, target_node_indices, adj_cluster], cluster_targets
def __node_buffer_dict_to_list(self):
self.node_order = []
for k, v in self.__node_buffer.items():
self.node_order.extend(v)
def on_epoch_end(self):
"""
Shuffle all nodes at the end of each epoch
"""
if self.q > 1:
# combine clusters
cluster_indices = list(range(len(self.clusters_original)))
random.shuffle(cluster_indices)
self.clusters = []
for i in range(0, len(cluster_indices) - 1, self.q):
cc = cluster_indices[i : i + self.q]
tmp = []
for l in cc:
tmp.extend(list(self.clusters_original[l]))
self.clusters.append(tmp)
else:
self.clusters = copy.deepcopy(self.clusters_original)
self.__node_buffer = dict()
random.shuffle(self.clusters)