Source code for stellargraph.mapper.mini_batch_node_generators

# -*- coding: utf-8 -*-
#
# Copyright 2018-2020 Data61, CSIRO
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Mappers to provide input data for the graph models in layers.

"""
__all__ = ["ClusterNodeGenerator", "ClusterNodeSequence"]

import random
import copy
import numpy as np
import networkx as nx
from tensorflow.keras.utils import Sequence

from scipy import sparse
from ..core.graph import StellarGraph
from ..core.utils import is_real_iterable


class ClusterNodeGenerator:
    """
    A data generator for use with ClusterGCN models on homogeneous graphs, [1].

    The supplied graph G should be a StellarGraph object that is ready for machine learning.
    Currently the model requires node features to be available for all nodes in the graph.
    Use the :meth:`flow` method supplying the nodes and (optionally) targets
    to get an object that can be used as a Keras data generator.

    This generator will supply the features array and the adjacency matrix to a
    mini-batch Keras graph ML model.

    [1] `W. Chiang, X. Liu, S. Si, Y. Li, S. Bengio, C. Hsieh, 2019 <https://arxiv.org/abs/1905.07953>`_.

    For more information, please see the ClusterGCN demo:
        `<https://github.com/stellargraph/stellargraph/blob/master/demos/>`_

    Args:
        G (StellarGraph): a machine-learning StellarGraph-type graph
        clusters (int or list): If int, it indicates the number of clusters (the default is 1,
            i.e., the whole graph is treated as a single cluster). If clusters is greater than 1,
            then nodes are uniformly at random assigned to a cluster. If list, then it should be
            a list of lists of node IDs such that each list corresponds to a cluster of nodes
            in G. The clusters should be non-overlapping.
        q (int): The number of clusters to combine for each mini-batch. The default is 1.
        lam (float): The mixture coefficient for adjacency matrix normalisation.
        name (str): an optional name of the generator
    """

    def __init__(self, G, clusters=1, q=1, lam=0.1, name=None):

        if not isinstance(G, StellarGraph):
            raise TypeError("Graph must be a StellarGraph or StellarDiGraph object.")

        self.graph = G
        self.name = name
        self.q = q  # The number of clusters to sample per mini-batch
        self.lam = lam
        self.clusters = clusters

        if isinstance(clusters, list):
            self.k = len(clusters)
        elif isinstance(clusters, int):
            if clusters <= 0:
                raise ValueError(
                    "{}: clusters must be greater than 0.".format(type(self).__name__)
                )
            self.k = clusters
        else:
            raise TypeError(
                "{}: clusters must be either int or list type.".format(
                    type(self).__name__
                )
            )

        # Some error checking on the given parameter values
        if not isinstance(lam, float):
            raise TypeError("{}: lam must be a float type.".format(type(self).__name__))

        if lam < 0 or lam > 1:
            raise ValueError(
                "{}: lam must be in the range [0, 1].".format(type(self).__name__)
            )

        if not isinstance(q, int):
            raise TypeError("{}: q must be integer type.".format(type(self).__name__))

        if q <= 0:
            raise ValueError(
                "{}: q must be greater than 0.".format(type(self).__name__)
            )

        if self.k % q != 0:
            raise ValueError(
                "{}: the number of clusters must be exactly divisible by q.".format(
                    type(self).__name__
                )
            )

        # Check if the graph has features
        G.check_graph_for_ml()

        self.node_list = list(G.nodes())

        # Check that there is only a single node type
        if len(G.node_types) > 1:
            raise ValueError(
                "{}: node generator requires graph with single node type; "
                "a graph with multiple node types is passed. Stopping.".format(
                    type(self).__name__
                )
            )

        if isinstance(clusters, int):
            # We are not given graph clusters.
            # We are going to split the graph into self.k random clusters
            all_nodes = list(G.nodes())
            random.shuffle(all_nodes)
            cluster_size = len(all_nodes) // self.k
            self.clusters = [
                all_nodes[i : i + cluster_size]
                for i in range(0, len(all_nodes), cluster_size)
            ]
            if len(self.clusters) > self.k:
                # for the case that the number of nodes is not exactly divisible by k, we combine
                # the last cluster with the second last one
                self.clusters[-2].extend(self.clusters[-1])
                del self.clusters[-1]

        print(f"Number of clusters {self.k}")
        for i, c in enumerate(self.clusters):
            print(f"{i} cluster has size {len(c)}")

        # Get the features for the nodes
        self.features = G.node_features(self.node_list)
    def flow(self, node_ids, targets=None, name=None):
        """
        Creates a generator/sequence object for training, evaluation, or prediction
        with the supplied node ids and numeric targets.

        Args:
            node_ids (iterable): an iterable of node ids for the nodes of interest
                (e.g., training, validation, or test set nodes)
            targets (2d array, optional): a 2D array of numeric node targets with shape
                `(len(node_ids), target_size)`
            name (str, optional): An optional name for the returned generator object.

        Returns:
            A ClusterNodeSequence object to use with ClusterGCN in Keras
            methods :meth:`fit_generator`, :meth:`evaluate_generator`,
            and :meth:`predict_generator`
        """
        if targets is not None:
            # Check targets is an iterable
            if not is_real_iterable(targets):
                raise TypeError(
                    "{}: Targets must be an iterable or None".format(
                        type(self).__name__
                    )
                )

            # Check targets correct shape
            if len(targets) != len(node_ids):
                raise ValueError(
                    "{}: Targets must be the same length as node_ids".format(
                        type(self).__name__
                    )
                )

        return ClusterNodeSequence(
            self.graph,
            self.clusters,
            targets=targets,
            node_ids=node_ids,
            q=self.q,
            lam=self.lam,
            name=name,
        )
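
# Illustrative usage sketch, not part of the stellargraph API: it only exercises
# ClusterNodeGenerator and flow() as defined above. The arguments ``graph``,
# ``node_ids`` and ``targets`` are hypothetical placeholders supplied by the caller.
def _example_cluster_generator_usage(graph, node_ids, targets):
    # Split the graph into 10 random clusters and combine 2 clusters per mini-batch.
    generator = ClusterNodeGenerator(graph, clusters=10, q=2, lam=0.1)

    # Create a Keras Sequence over the nodes of interest and their targets.
    sequence = generator.flow(node_ids, targets)

    # Each batch is ([features, target_node_indices, adjacency], batch_targets),
    # all with a leading batch dimension of 1 added by ClusterNodeSequence.
    [features, target_node_indices, adjacency], batch_targets = sequence[0]
    return features.shape, adjacency.shape, batch_targets.shape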
class ClusterNodeSequence(Sequence):
    """
    A Keras-compatible data generator for node inference using the ClusterGCN model.
    Use this class with the Keras methods :meth:`keras.Model.fit_generator`,
    :meth:`keras.Model.evaluate_generator`, and :meth:`keras.Model.predict_generator`.
    This class should be created using the `.flow(...)` method of
    :class:`ClusterNodeGenerator`.

    Args:
        graph (StellarGraph): The graph
        clusters (list): A list of lists such that each sub-list indicates the nodes in a cluster.
            The length of this list, len(clusters), indicates the number of batches in one epoch.
        targets (np.ndarray, optional): An optional array of node targets of size (N x C),
            where C is the target size (e.g., number of classes for one-hot class targets)
        node_ids (iterable, optional): The node IDs for the target nodes. Required if targets is
            not None.
        normalize_adj (bool, optional): Specifies whether the adjacency matrix for each mini-batch
            should be normalized or not. The default is True.
        q (int, optional): The number of subgraphs to combine for each batch. The default value is
            1 such that the generator treats each subgraph as a batch.
        lam (float, optional): The mixture coefficient for adjacency matrix normalisation (the
            'diagonal enhancement' method). Valid values are in the interval [0, 1] and the
            default value is 0.1.
        name (str, optional): An optional name for this generator object.
    """

    def __init__(
        self,
        graph,
        clusters,
        targets=None,
        node_ids=None,
        normalize_adj=True,
        q=1,
        lam=0.1,
        name=None,
    ):
        self.name = name
        self.clusters = list()
        self.clusters_original = copy.deepcopy(clusters)
        self.graph = graph
        self.node_list = list(graph.nodes())
        self.normalize_adj = normalize_adj
        self.q = q
        self.lam = lam
        self.node_order = list()
        self._node_order_in_progress = list()
        self.__node_buffer = dict()
        self.target_ids = list()

        if len(clusters) % self.q != 0:
            raise ValueError(
                "The number of clusters should be exactly divisible by q. However, {} number of clusters is not exactly divisible by {}.".format(
                    len(clusters), q
                )
            )

        if node_ids is not None:
            self.target_ids = list(node_ids)

        if targets is not None:
            if node_ids is None:
                raise ValueError(
                    "Since targets is not None, node_ids must be given and cannot be None."
                )

            if len(node_ids) != len(targets):
                raise ValueError(
                    "When passed together targets and indices should be the same length."
                )

            self.targets = np.asanyarray(targets)

            self.target_node_lookup = dict(
                zip(self.target_ids, range(len(self.target_ids)))
            )
        else:
            self.targets = None

        self.on_epoch_end()

    def __len__(self):
        num_batches = len(self.clusters_original) // self.q
        return num_batches

    def __getitem__(self, index):
        # The next batch should be the adjacency matrix for the cluster and the corresponding
        # feature vectors and targets if available.
        cluster = self.clusters[index]
        adj_cluster = self.graph.to_adjacency_matrix(cluster)

        # The operations to normalize the adjacency matrix are too slow.
        # Either optimize this or implement as a layer(?)
        if self.normalize_adj:
            # add self loops
            adj_cluster.setdiag(1)
            degree_matrix_diag = 1.0 / (adj_cluster.sum(axis=1) + 1)
            degree_matrix_diag = np.squeeze(np.asarray(degree_matrix_diag))
            degree_matrix = sparse.lil_matrix(adj_cluster.shape)
            degree_matrix.setdiag(degree_matrix_diag)
            adj_cluster = degree_matrix.tocsr() @ adj_cluster
            adj_cluster.setdiag((1.0 + self.lam) * adj_cluster.diagonal())

        adj_cluster = adj_cluster.toarray()

        g_node_list = list(cluster)

        # Determine the target nodes that exist in this cluster
        target_nodes_in_cluster = np.asanyarray(
            list(set(g_node_list).intersection(self.target_ids))
        )

        self.__node_buffer[index] = target_nodes_in_cluster

        # Dictionary to store node indices for quicker node index lookups
        node_lookup = dict(zip(g_node_list, range(len(g_node_list))))

        # The list of indices of the target nodes in g_node_list
        target_node_indices = np.array(
            [node_lookup[n] for n in target_nodes_in_cluster]
        )

        if index == (len(self.clusters_original) // self.q) - 1:
            # last batch
            self.__node_buffer_dict_to_list()

        cluster_targets = None
        if self.targets is not None:
            # The indices into self.targets of the target nodes in this cluster
            cluster_target_indices = np.array(
                [self.target_node_lookup[n] for n in target_nodes_in_cluster]
            )
            cluster_targets = self.targets[cluster_target_indices]
            cluster_targets = cluster_targets.reshape((1,) + cluster_targets.shape)

        features = self.graph.node_features(g_node_list)

        features = np.reshape(features, (1,) + features.shape)
        adj_cluster = adj_cluster.reshape((1,) + adj_cluster.shape)
        target_node_indices = target_node_indices[np.newaxis, np.newaxis, :]

        return [features, target_node_indices, adj_cluster], cluster_targets

    def __node_buffer_dict_to_list(self):
        self.node_order = []
        for k, v in self.__node_buffer.items():
            self.node_order.extend(v)

    def on_epoch_end(self):
        """
        Shuffle the clusters (combining them in groups of q when q > 1) at the end of each epoch.
        """
        if self.q > 1:
            # combine clusters
            cluster_indices = list(range(len(self.clusters_original)))
            random.shuffle(cluster_indices)
            self.clusters = []

            for i in range(0, len(cluster_indices) - 1, self.q):
                cc = cluster_indices[i : i + self.q]
                tmp = []
                for l in cc:
                    tmp.extend(list(self.clusters_original[l]))
                self.clusters.append(tmp)
        else:
            self.clusters = copy.deepcopy(self.clusters_original)

        self.__node_buffer = dict()

        random.shuffle(self.clusters)
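
# Illustrative sketch, not part of the stellargraph API: the same per-cluster
# adjacency normalisation that ClusterNodeSequence.__getitem__ applies when
# ``normalize_adj`` is True, shown on a standalone square adjacency matrix ``adj``
# (dense or scipy sparse); ``lam`` is the diagonal-enhancement coefficient.
def _example_normalized_adjacency(adj, lam=0.1):
    adj = sparse.csr_matrix(adj, dtype="float32")

    # add self loops
    adj.setdiag(1)

    # inverse "degree" as computed by the generator: 1 / (row sum of (A + I) + 1)
    degree_matrix_diag = np.squeeze(np.asarray(1.0 / (adj.sum(axis=1) + 1)))
    degree_matrix = sparse.lil_matrix(adj.shape)
    degree_matrix.setdiag(degree_matrix_diag)

    # row-normalise, then scale the diagonal by (1 + lam) ('diagonal enhancement')
    adj_norm = degree_matrix.tocsr() @ adj
    adj_norm.setdiag((1.0 + lam) * adj_norm.diagonal())

    return adj_norm.toarray()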