# -*- coding: utf-8 -*-
#
# Copyright 2020 Data61, CSIRO
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core.graph import StellarGraph
from ..core.utils import is_real_iterable
from .sequences import PaddedGraphSequence
from .base import Generator
[docs]class PaddedGraphGenerator(Generator):
"""
A data generator for use with graph classification algorithms.
The supplied graphs should be :class:`StellarGraph` objects ready for machine learning. The generator
requires node features to be available for all nodes in the graph.
Use the :meth:`flow` method supplying the graph indexes and (optionally) targets
to get an object that can be used as a Keras data generator.
This generator supplies the features arrays and the adjacency matrices to a mini-batch Keras
graph classification model. Differences in the number of nodes are resolved by padding each
batch of features and adjacency matrices, and supplying a boolean mask indicating which are
valid and which are padding.
Args:
graphs (list): a collection of ready for machine-learning StellarGraph-type objects
name (str): an optional name of the generator
"""
def __init__(self, graphs, name=None):
self.node_features_size = None
for graph in graphs:
if not isinstance(graph, StellarGraph):
raise TypeError(
f"graphs: expected every element to be a StellarGraph object, found {type(graph).__name__}."
)
if len(graph.node_types) > 1:
raise ValueError(
"graphs: node generator requires graphs with single node type, "
f"found a graph with {len(graph.node_types)} node types."
)
graph.check_graph_for_ml()
# we require that all graphs have node features of the same dimensionality
f_dim = graph.node_feature_sizes()[list(graph.node_types)[0]]
if self.node_features_size is None:
self.node_features_size = f_dim
elif self.node_features_size != f_dim:
raise ValueError(
"graphs: expected node features for all graph to have same dimensions,"
f"found {self.node_features_size} vs {f_dim}"
)
self.graphs = graphs
self.name = name
[docs] def num_batch_dims(self):
return 1
[docs] def flow(
self,
graph_ilocs,
targets=None,
symmetric_normalization=True,
batch_size=1,
name=None,
):
"""
Creates a generator/sequence object for training, evaluation, or prediction
with the supplied graph indexes and targets.
Args:
graph_ilocs (iterable): an iterable of graph indexes in self.graphs for the graphs of interest
(e.g., training, validation, or test set nodes).
targets (2d array, optional): a 2D array of numeric graph targets with shape `(len(graph_ilocs),
len(targets))`.
symmetric_normalization (bool, optional): The type of normalization to be applied on the graph adjacency
matrices. If True, the adjacency matrix is left and right multiplied by the inverse square root of the
degree matrix; otherwise, the adjacency matrix is only left multiplied by the inverse of the degree
matrix.
batch_size (int, optional): The batch size.
name (str, optional): An optional name for the returned generator object.
Returns:
A :class:`PaddedGraphSequence` object to use with Keras methods :meth:`fit`, :meth:`evaluate`, and :meth:`predict`
"""
if targets is not None:
# Check targets is an iterable
if not is_real_iterable(targets):
raise TypeError(
f"targets: expected an iterable or None object, found {type(targets).__name__}"
)
# Check targets correct shape
if len(targets) != len(graph_ilocs):
raise ValueError(
f"expected targets to be the same length as node_ids, found {len(targets)} vs {len(graph_ilocs)}"
)
if not isinstance(batch_size, int):
raise TypeError(
f"expected batch_size to be integer type, found {type(batch_size).__name__}"
)
if batch_size <= 0:
raise ValueError(
f"expected batch_size to be strictly positive integer, found {batch_size}"
)
return PaddedGraphSequence(
graphs=[self.graphs[i] for i in graph_ilocs],
targets=targets,
symmetric_normalization=symmetric_normalization,
batch_size=batch_size,
name=name,
)