"""
Mappers to provide input data for the graph models in layers.

"""
__all__ = [
    "ClusterNodeGenerator",
    "ClusterNodeSequence",
]

import random
import copy
import numpy as np
import networkx as nx
from tensorflow.keras.utils import Sequence

from scipy import sparse
from ..core.graph import StellarGraph
from ..connector.neo4j.graph import Neo4jStellarGraph
from .base import Generator

class ClusterNodeGenerator(Generator):
    """
    A data generator for use with GCN, GAT and APPNP models on homogeneous graphs, see [1].

    The supplied graph G should be a StellarGraph object with node features.
    Use the :meth:`flow` method supplying the nodes and (optionally) targets
    to get an object that can be used as a Keras data generator.

    This generator will supply the features array and the adjacency matrix to a
    mini-batch Keras graph ML model.

    [1] `W. Chiang, X. Liu, S. Si, Y. Li, S. Bengio, C. Hsieh, 2019 <https://arxiv.org/abs/1905.07953>`_.

    .. seealso::

       Models using this generator: :class:`.GCN`, :class:`.GAT`, :class:`.APPNP`.

       Examples using this generator:

       - `Cluster-GCN node classification <https://stellargraph.readthedocs.io/en/stable/demos/node-classification/cluster-gcn-node-classification.html>`__
       - `Cluster-GCN node classification with Neo4j <https://stellargraph.readthedocs.io/en/stable/demos/connector/neo4j/cluster-gcn-on-cora-neo4j-example.html>`__
       - `unsupervised representation learning with Deep Graph Infomax <https://stellargraph.readthedocs.io/en/stable/demos/embeddings/deep-graph-infomax-embeddings.html>`__

    Args:
        G (StellarGraph or Neo4jStellarGraph): a machine-learning StellarGraph-type graph
            with node features and a single node type.
        clusters (int or list, optional): If int, it indicates the number of clusters (default
            is 1, corresponding to the entire graph). If clusters is greater than 1, the nodes
            are randomly shuffled and split into exactly that many clusters: the first
            ``clusters - 1`` clusters each hold ``num_nodes // clusters`` nodes and the last
            cluster holds the rest. An int value must not exceed the number of nodes.
            If list, then it should be a list of lists of node IDs, such that each list
            corresponds to a cluster of nodes in G. The clusters should be non-overlapping.
        q (int, optional): The number of clusters to combine for each mini-batch (default is 1).
            The total number of clusters must be divisible by q.
        lam (float, optional): The mixture coefficient for adjacency matrix normalisation
            (default is 0.1). Valid values are in the interval [0, 1].
        weighted (bool, optional): if True, use the edge weights from G; if False, treat the
            graph as unweighted.
        name (str, optional): Name for the node generator.

    Raises:
        TypeError: if G is not a StellarGraph/Neo4jStellarGraph, or clusters, lam or q have
            the wrong type.
        ValueError: if clusters, lam or q are out of range, the number of clusters is not
            divisible by q, or an int clusters exceeds the number of nodes.
    """

    def __init__(self, G, clusters=1, q=1, lam=0.1, weighted=False, name=None):

        if not isinstance(G, (StellarGraph, Neo4jStellarGraph)):
            raise TypeError("Graph must be a StellarGraph or StellarDiGraph object.")

        self.graph = G
        self.name = name
        self.q = q  # The number of clusters to sample per mini-batch
        self.lam = lam
        self.clusters = clusters
        self.method = "cluster_gcn"
        self.multiplicity = 1
        self.use_sparse = False
        self.weighted = weighted

        if isinstance(clusters, list):
            self.k = len(clusters)
        elif isinstance(clusters, int):
            if clusters <= 0:
                raise ValueError(
                    "{}: clusters must be greater than 0.".format(type(self).__name__)
                )
            self.k = clusters
        else:
            raise TypeError(
                "{}: clusters must be either int or list type.".format(
                    type(self).__name__
                )
            )

        # Some error checking on the given parameter values
        if not isinstance(lam, float):
            raise TypeError("{}: lam must be a float type.".format(type(self).__name__))

        if lam < 0 or lam > 1:
            raise ValueError(
                "{}: lam must be in the range [0, 1].".format(type(self).__name__)
            )

        if not isinstance(q, int):
            raise TypeError("{}: q must be integer type.".format(type(self).__name__))

        if q <= 0:
            raise ValueError(
                "{}: q must be greater than 0.".format(type(self).__name__)
            )

        if self.k % q != 0:
            raise ValueError(
                "{}: the number of clusters must be exactly divisible by q.".format(
                    type(self).__name__
                )
            )

        self.node_list = list(G.nodes())

        # if graph is a StellarGraph check that the graph has features
        G.check_graph_for_ml(expensive_check=False)
        # Check that there is only a single node type
        _ = G.unique_node_type(
            "G: expected a graph with a single node type, found a graph with node types: %(found)s"
        )

        if isinstance(clusters, int):
            # We are not given graph clusters, so split the nodes into self.k random clusters
            self.clusters = self._random_clusters(self.node_list)

        print(f"Number of clusters {self.k}")
        for i, c in enumerate(self.clusters):
            print(f"{i} cluster has size {len(c)}")

        # Store the features of one node to allow graph ML models to peek at the feature dimension
        # FIXME 1621: store feature_dimension here instead of features. This must also update ClusterGCN, and all
        # full-batch methods and generators
        self.features = G.node_features(self.node_list[:1])

    def _random_clusters(self, nodes):
        """
        Randomly partition ``nodes`` into exactly ``self.k`` non-overlapping clusters.

        A shuffled copy of ``nodes`` is split (``nodes`` itself is not modified): the first
        ``self.k - 1`` clusters each get ``len(nodes) // self.k`` nodes and the last cluster
        gets all remaining nodes.

        Args:
            nodes (list): the node IDs to partition.

        Returns:
            list: a list of ``self.k`` lists of node IDs.

        Raises:
            ValueError: if ``self.k`` exceeds the number of nodes (some clusters would be empty).
        """
        num_nodes = len(nodes)
        if self.k > num_nodes:
            raise ValueError(
                "{}: clusters ({}) must not exceed the number of nodes ({}).".format(
                    type(self).__name__, self.k, num_nodes
                )
            )

        shuffled = list(nodes)
        random.shuffle(shuffled)
        cluster_size = num_nodes // self.k

        # Fixed-size slices for all but the last cluster; the last one absorbs the
        # num_nodes % k leftover nodes, so exactly k clusters are always produced
        # (chunking by cluster_size alone can leave several leftover chunks).
        clusters = [
            shuffled[i * cluster_size : (i + 1) * cluster_size]
            for i in range(self.k - 1)
        ]
        clusters.append(shuffled[(self.k - 1) * cluster_size :])
        return clusters

    def num_batch_dims(self):
        """
        Number of leading batch dimensions in the inputs produced by this generator.

        Returns 2: the sequence reshapes each mini-batch array to carry a leading axis of
        size 1 in front of the per-node axis.
        """
        return 2

    def flow(self, node_ids, targets=None, name=None):
        """
        Creates a generator/sequence object for training, evaluation, or prediction
        with the supplied node ids and numeric targets.

        Args:
            node_ids (iterable): an iterable of node ids for the nodes of interest
                (e.g., training, validation, or test set nodes)
            targets (2d array, optional): a 2D array of numeric node targets with shape
                ``(len(node_ids), target_size)``
            name (str, optional): An optional name for the returned generator object.

        Returns:
            A :class:`ClusterNodeSequence` object to use with :class:`.GCN`, :class:`.GAT` or
            :class:`.APPNP` in Keras methods :meth:`fit`, :meth:`evaluate`, and :meth:`predict`.

        Raises:
            TypeError: if targets is not an iterable (strings/bytes are rejected).
            ValueError: if targets and node_ids have different lengths.
        """
        import collections.abc

        if node_ids is not None:
            # Materialise once: one-shot iterables (e.g. generators) cannot be measured
            # with len() and would otherwise be consumed before reaching the sequence.
            node_ids = list(node_ids)

        if targets is not None:
            # Check targets is an iterable, but not a string/bytes object
            if not isinstance(targets, collections.abc.Iterable) or isinstance(
                targets, (str, bytes)
            ):
                raise TypeError(
                    "{}: Targets must be an iterable or None".format(
                        type(self).__name__
                    )
                )

            # Check targets correct shape
            if len(targets) != len(node_ids):
                raise ValueError(
                    "{}: Targets must be the same length as node_ids".format(
                        type(self).__name__
                    )
                )

        return ClusterNodeSequence(
            self.graph,
            self.clusters,
            targets=targets,
            node_ids=node_ids,
            q=self.q,
            lam=self.lam,
            weighted=self.weighted,
            name=name,
        )

    def default_corrupt_input_index_groups(self):
        """
        Groups of model-input indices to corrupt (e.g. for Deep Graph Infomax).

        Returns ``[[0]]``: only input 0 is corrupted. Input 0 is expected to be the node
        features array produced by :class:`ClusterNodeSequence`; confirm against its
        batch layout if that changes.
        """
        return [[0]]

class ClusterNodeSequence(Sequence):
    """
    A Keras-compatible data generator for node inference using ClusterGCN model.
    Use this class with the Keras methods :meth:`keras.Model.fit`,
    :meth:`keras.Model.evaluate`, and
    :meth:`keras.Model.predict`.

    This class should be created using the :meth:`flow` method of
    :class:`.ClusterNodeGenerator`.

    Each batch is built from ``q`` clusters; at the end of every epoch the clusters are
    regrouped at random and the batch order is shuffled. One epoch has
    ``len(clusters) // q`` batches.

    Args:
        graph (StellarGraph): The graph
        clusters (list): A list of lists such that each sub-list indicates the nodes in a cluster.
            Its length must be exactly divisible by ``q``.
        targets (np.ndarray, optional): An optional array of node targets of size (N x C),
            where C is the target size (e.g., number of classes for one-hot class targets)
        node_ids (iterable, optional): The node IDs for the target nodes. Required if targets is not None.
        q (int, optional): The number of subgraphs to combine for each batch. The default value is
            1 such that the generator treats each subgraph as a batch.
        lam (float, optional): The mixture coefficient for adjacency matrix normalisation (the
            'diagonal enhancement' method). Valid values are in the interval [0, 1] and the default value is 0.1.
        weighted (bool, optional): if True, use the edge weights from ``graph`` when building each
            batch's adjacency matrix; if False, treat the graph as unweighted.
        name (str, optional): An optional name for this generator object.

    Attributes:
        node_order (list): filled when the last batch index of an epoch is requested; the target
            node IDs of every batch requested so far, concatenated in batch-index order.

    Raises:
        ValueError: if the number of clusters is not divisible by q, or targets are given
            without node_ids or with a different length.
    """

    def __init__(
        self,
        graph,
        clusters,
        targets=None,
        node_ids=None,
        q=1,
        lam=0.1,
        weighted=False,
        name=None,
    ):

        self.name = name
        self.clusters = list()
        self.clusters_original = copy.deepcopy(clusters)
        self.graph = graph
        self.node_list = list(graph.nodes())
        self.q = q
        self.lam = lam
        self.weighted = weighted
        self.node_order = list()
        self._node_order_in_progress = list()
        self.__node_buffer = dict()
        self.target_ids = list()

        if len(clusters) % self.q != 0:
            raise ValueError(
                "The number of clusters should be exactly divisible by q. However, {} number of clusters is not exactly divisible by {}.".format(
                    len(clusters), q
                )
            )

        if node_ids is not None:
            self.target_ids = list(node_ids)
        # Set for O(1) membership tests when selecting each batch's target nodes
        self._target_id_set = set(self.target_ids)

        if targets is not None:
            if node_ids is None:
                raise ValueError(
                    "Since targets is not None, node_ids must be given and cannot be None."
                )

            # Compare against the materialised list: node_ids may be a one-shot iterable
            if len(self.target_ids) != len(targets):
                raise ValueError(
                    "When passed together targets and indices should be the same length."
                )

            self.targets = np.asanyarray(targets)
            self.target_node_lookup = dict(
                zip(self.target_ids, range(len(self.target_ids)))
            )
        else:
            self.targets = None

        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch: the number of clusters divided by ``q``."""
        num_batches = len(self.clusters_original) // self.q
        return num_batches

    def _diagonal_enhanced_normalize(self, adj_cluster):
        """
        Normalise one batch's adjacency matrix with Cluster-GCN diagonal enhancement.

        Args:
            adj_cluster: square adjacency matrix of the batch's nodes, either a SciPy sparse
                matrix or a dense array-like.

        Returns:
            np.ndarray: the dense matrix ``A~ + lam * diag(A~)`` where ``A~ = N(A + I)`` and
            ``N = (D + I)^(-1)``.
        """
        # Cluster-GCN normalization is:
        #     A~ + λdiag(A~) where A~ = N(A + I) with normalization factor N = (D + I)^(-1)
        #
        # Expands to:
        #     NA + NI + λN(diag(A) + I) =
        #     NA + N(I + λ(diag(A) + I))
        #
        # The last form is used below. The further rewrite λN(diag(A) + (1 + 1/λ)I)
        # divides by λ and so fails for lam == 0, which is a valid value.
        #
        # (This could potentially become a layer, to benefit from a GPU.)
        if sparse.issparse(adj_cluster):
            adj_dense = adj_cluster.toarray()
        else:
            adj_dense = np.asarray(adj_cluster)

        # D: row sums of A (weighted degrees when edge weights are used)
        degrees = adj_dense.sum(axis=1)
        normalization = 1 / (degrees + 1)

        # NA: multiply rows manually
        norm_adj = adj_dense * normalization[:, np.newaxis]

        # N(I + λ(diag(A) + I)): work with the diagonals directly
        diag_addition = normalization * (1 + self.lam * (adj_dense.diagonal() + 1))
        np.fill_diagonal(norm_adj, norm_adj.diagonal() + diag_addition)
        return norm_adj

    def __getitem__(self, index):
        """
        Build the mini-batch for the (combined) cluster at position ``index``.

        Returns:
            tuple: ``([features, target_node_indices, adj_cluster], cluster_targets)``; each
            array gets a leading axis of size 1, and ``cluster_targets`` is None when no
            targets were supplied.
        """
        # The next batch should be the adjacency matrix for the cluster and the corresponding feature vectors
        # and targets if available.
        cluster = self.clusters[index]
        g_node_list = list(cluster)

        # NOTE(review): the adjacency lines were missing from the reviewed copy of this file and
        # have been reconstructed. This assumes graph.to_adjacency_matrix(nodes, weighted=...)
        # returns the adjacency restricted to g_node_list, in that order; confirm for both
        # StellarGraph and Neo4jStellarGraph.
        adj_cluster = self.graph.to_adjacency_matrix(g_node_list, weighted=self.weighted)
        adj_cluster = self._diagonal_enhanced_normalize(adj_cluster)

        # Positions (within this batch) of the target nodes, in cluster order so the output
        # order is deterministic (set iteration order is not)
        target_positions = [
            position
            for position, node in enumerate(g_node_list)
            if node in self._target_id_set
        ]
        target_nodes_in_cluster = np.asanyarray(
            [g_node_list[position] for position in target_positions]
        )

        self.__node_buffer[index] = target_nodes_in_cluster

        # Explicit int dtype: an empty list would otherwise become a float64 array,
        # which is not usable as gather indices
        target_node_indices = np.array(target_positions, dtype=np.int64)

        if index == len(self) - 1:
            # last batch
            self.__node_buffer_dict_to_list()

        cluster_targets = None
        if self.targets is not None:
            # Rows of self.targets belonging to this batch's target nodes
            cluster_target_indices = np.array(
                [
                    self.target_node_lookup[g_node_list[position]]
                    for position in target_positions
                ],
                dtype=np.int64,
            )
            cluster_targets = self.targets[cluster_target_indices]
            cluster_targets = cluster_targets.reshape((1,) + cluster_targets.shape)

        features = self.graph.node_features(g_node_list)

        features = np.reshape(features, (1,) + features.shape)
        adj_cluster = adj_cluster[np.newaxis, ...]
        target_node_indices = target_node_indices[np.newaxis, :]

        return [features, target_node_indices, adj_cluster], cluster_targets

    def __node_buffer_dict_to_list(self):
        # Concatenate the buffered target nodes in batch-index order, so node_order lines up
        # with outputs collected for batches 0, 1, 2, ... even if __getitem__ was called
        # in a different order.
        self.node_order = []
        for batch_index in sorted(self.__node_buffer):
            self.node_order.extend(self.__node_buffer[batch_index])

    def on_epoch_end(self):
        """
        Regroup the clusters into batches and shuffle the batch order.

        With ``q > 1`` the original clusters are shuffled and consecutive groups of ``q``
        are merged into one batch; with ``q == 1`` each original cluster is its own batch.
        The list of batches is then shuffled and the per-batch target-node buffer is cleared.
        Also called once from ``__init__`` to set up the first epoch.
        """
        if self.q > 1:
            # combine clusters
            cluster_indices = list(range(len(self.clusters_original)))
            random.shuffle(cluster_indices)
            self.clusters = []

            for start in range(0, len(cluster_indices), self.q):
                merged = []
                for cluster_idx in cluster_indices[start : start + self.q]:
                    merged.extend(list(self.clusters_original[cluster_idx]))
                self.clusters.append(merged)
        else:
            self.clusters = copy.deepcopy(self.clusters_original)

        self.__node_buffer = dict()

        random.shuffle(self.clusters)