Link prediction with GCN
In this example, we use our implementation of the GCN algorithm to build a model that predicts citation links in the Cora dataset (see below). The problem is treated as a supervised link prediction problem on a homogeneous citation network with nodes representing papers (with attributes such as binary keyword indicators and categorical subject) and links corresponding to paper-paper citations.
To address this problem, we build a model with the following architecture. First, we build a two-layer GCN model that takes labeled node pairs (citing-paper -> cited-paper) corresponding to possible citation links, and outputs a pair of node embeddings for the citing-paper and cited-paper nodes of the pair. These embeddings are then fed into a link classification layer, which first applies a binary operator to those node embeddings (e.g., concatenating them) to construct the embedding of the potential link. The resulting link embeddings are passed through the dense link classification layer to obtain link predictions, i.e. the probabilities that these candidate links actually exist in the network. The entire model is trained end-to-end by minimizing the loss function of choice (e.g., binary cross-entropy between predicted link probabilities and true link labels, with true/false citation links having labels 1/0) using gradient-based updates of the model parameters (in this notebook, the Adam optimizer, with all 'training' links fed into the model as a single full batch).
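As a minimal sketch of the scoring and loss described above, the snippet below combines two node embeddings with an inner product, maps the score to a probability with a sigmoid, and computes the mean binary cross-entropy against 1/0 link labels. The embeddings, the helper names and the sigmoid output are illustrative assumptions in plain Python, not StellarGraph's implementation.

```python
import math

def sigmoid(x):
    """Logistic function: maps a raw link score to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def inner_product(u, v):
    return sum(a * b for a, b in zip(u, v))

def binary_cross_entropy(probs, labels, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clipped to avoid log(0)."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(probs)

# Made-up 3-dimensional node embeddings (a trained GCN would produce these).
emb = {"a": [1.0, 0.5, 0.0], "b": [0.8, 0.4, 0.1], "c": [-0.6, 0.0, 0.9]}
pairs = [("a", "b"), ("a", "c")]  # candidate (citing-paper, cited-paper) links
labels = [1, 0]                   # 1 = true citation, 0 = false citation

probs = [sigmoid(inner_product(emb[s], emb[t])) for s, t in pairs]
loss = binary_cross_entropy(probs, labels)
print([round(p, 4) for p in probs], round(loss, 4))
```

Minimizing this loss pushes the score of the true pair up and the score of the false pair down, which is what end-to-end training does to the GCN weights.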
[3]:
import stellargraph as sg
from stellargraph.data import EdgeSplitter
from stellargraph.mapper import FullBatchLinkGenerator
from stellargraph.layer import GCN, LinkEmbedding
from tensorflow import keras
from sklearn import preprocessing, feature_extraction, model_selection
from stellargraph import globalvar
from stellargraph import datasets
from IPython.display import display, HTML
%matplotlib inline
Loading the Cora network data
(See the “Loading from Pandas” demo for details on how data can be loaded.)
[4]:
dataset = datasets.Cora()
display(HTML(dataset.description))
G, _ = dataset.load(subject_as_feature=True)
[5]:
print(G.info())
StellarGraph: Undirected multigraph
Nodes: 2708, Edges: 5429
Node types:
paper: [2708]
Features: float32 vector, length 1440
Edge types: paper-cites->paper
Edge types:
paper-cites->paper: [5429]
We aim to train a link prediction model, hence we need to prepare the train and test sets of links and the corresponding graphs with those links removed.
We are going to split our input graph into train and test graphs using the EdgeSplitter class in stellargraph.data. We will use the train graph for training the model (a binary classifier that, given two nodes, predicts whether a link between these two nodes should exist or not) and the test graph for evaluating the model's performance on held-out data. Each of these graphs will have the same number of nodes as the input graph but fewer links, because some of the links are removed during each split and used as the positive samples for training/testing the link prediction classifier.
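The following is a simplified, hypothetical sketch of a global edge split in plain Python: it samples a fraction p of the existing edges as positive examples, removes them from the graph, and draws the same number of node pairs that are not edges as negative examples. The function split_edges is illustrative only; unlike EdgeSplitter with keep_connected=True, it does not check that the reduced graph stays connected.

```python
import random

def split_edges(nodes, edges, p, seed=None):
    """Hypothetical simplified global edge split (not StellarGraph's EdgeSplitter).

    Returns the reduced edge list plus the sampled positive and negative
    node pairs with their 1/0 labels.
    """
    rng = random.Random(seed)
    n_pos = int(p * len(edges))
    # Positive examples: a random subset of existing edges.
    positives = rng.sample(edges, n_pos)
    existing = {frozenset(e) for e in edges}
    # Negative examples: random node pairs that are not edges (no self-loops).
    # Assumes enough non-edges exist, otherwise this loop would not terminate.
    negatives, seen = [], set()
    while len(negatives) < n_pos:
        u, v = rng.sample(nodes, 2)
        pair = frozenset((u, v))
        if pair not in existing and pair not in seen:
            seen.add(pair)
            negatives.append((u, v))
    # Remove the positive examples from the graph that the model will see.
    removed = {frozenset(e) for e in positives}
    reduced_edges = [e for e in edges if frozenset(e) not in removed]
    edge_ids = positives + negatives
    edge_labels = [1] * n_pos + [0] * n_pos
    return reduced_edges, edge_ids, edge_labels

# Toy undirected ring graph with 20 nodes and 20 edges.
nodes = list(range(20))
edges = [(i, (i + 1) % 20) for i in range(20)]
reduced, ids, labels = split_edges(nodes, edges, p=0.1, seed=0)
print(len(reduced), len(ids), labels)
```

Removing the positive examples from the graph matters: otherwise the GCN could simply read the answer from the adjacency structure it propagates over.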
From the original graph G, extract a randomly sampled subset of test edges (true and false citation links) and the reduced graph G_test with the positive test edges removed:
[6]:
# Define an edge splitter on the original graph G:
edge_splitter_test = EdgeSplitter(G)
# Randomly sample a fraction p=0.1 of all positive links, and same number of negative links, from G, and obtain the
# reduced graph G_test with the sampled links removed:
G_test, edge_ids_test, edge_labels_test = edge_splitter_test.train_test_split(
    p=0.1, method="global", keep_connected=True
)
** Sampled 542 positive and 542 negative edges. **
The reduced graph G_test, together with the test ground truth set of links (edge_ids_test, edge_labels_test), will be used for testing the model.
Now repeat this procedure to obtain the training data for the model. From the reduced graph G_test, extract a randomly sampled subset of train edges (true and false citation links) and the reduced graph G_train with the positive train edges removed:
[7]:
# Define an edge splitter on the reduced graph G_test:
edge_splitter_train = EdgeSplitter(G_test)
# Randomly sample a fraction p=0.1 of all positive links, and same number of negative links, from G_test, and obtain the
# reduced graph G_train with the sampled links removed:
G_train, edge_ids_train, edge_labels_train = edge_splitter_train.train_test_split(
    p=0.1, method="global", keep_connected=True
)
** Sampled 488 positive and 488 negative edges. **
G_train, together with the train ground truth set of links (edge_ids_train, edge_labels_train), will be used for training the model.
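The sampled counts reported above are consistent with taking p times the current number of edges and truncating to an integer; this rounding is an assumption inferred from the printed numbers, not taken from EdgeSplitter's source. Under that assumption, the edge counts of the two reduced graphs follow directly:

```python
# Edge count of the Cora graph as reported by G.info() above.
n_edges = 5429
p = 0.1

n_test_pos = int(p * n_edges)                # positive test edges sampled from G
n_after_test = n_edges - n_test_pos          # edges left in G_test
n_train_pos = int(p * n_after_test)          # positive train edges sampled from G_test
n_after_train = n_after_test - n_train_pos   # edges left in G_train
print(n_test_pos, n_after_test, n_train_pos, n_after_train)
```

This gives 542 and 488 sampled positives (matching the logs), leaving roughly 4887 edges in G_test and 4399 in G_train, while both graphs keep all 2708 nodes.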
Creating the GCN link model
Next, we create the link generators that feed the train and test link examples to the model. Each link generator takes the pairs of nodes (citing-paper, cited-paper) passed to its .flow method and feeds them to the Keras model, together with the corresponding binary labels indicating whether those pairs represent true or false links.
The number of epochs for training the model:
[8]:
epochs = 50
For training, we create a generator on the G_train graph, and make an iterator over the training links using the generator's flow() method:
[9]:
train_gen = FullBatchLinkGenerator(G_train, method="gcn")
train_flow = train_gen.flow(edge_ids_train, edge_labels_train)
Using GCN (local pooling) filters...
[10]:
test_gen = FullBatchLinkGenerator(G_test, method="gcn")
# Use the generator built on G_test, so test links are scored on the test graph
test_flow = test_gen.flow(edge_ids_test, edge_labels_test)
Using GCN (local pooling) filters...
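The message "Using GCN (local pooling) filters..." refers to the propagation matrix used by GCN. The standard GCN formulation adds self-loops and normalizes symmetrically, A_norm = D^(-1/2) (A + I) D^(-1/2), where D is the degree matrix of A + I. The plain-Python sketch below computes this on a tiny dense matrix; we assume, rather than quote, that method="gcn" precomputes an equivalent matrix for the whole graph.

```python
import math

def gcn_normalize(adj):
    """Symmetric GCN normalization D^-1/2 (A + I) D^-1/2 on a dense matrix."""
    n = len(adj)
    # Add self-loops so each node keeps its own features during propagation.
    a_hat = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    # Degree of each node in A + I.
    deg = [sum(row) for row in a_hat]
    inv_sqrt = [1.0 / math.sqrt(d) for d in deg]
    return [[inv_sqrt[i] * a_hat[i][j] * inv_sqrt[j] for j in range(n)] for i in range(n)]

# Path graph 0 - 1 - 2 (undirected, so A is symmetric).
A = [[0, 1, 0],
     [1, 0, 1],
     [0, 1, 0]]
A_norm = gcn_normalize(A)
for row in A_norm:
    print([round(x, 3) for x in row])
```

Because a full-batch generator works with this matrix for the entire graph, each epoch is a single step ("1/1" in the training log below).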
Now we can specify our machine learning model. We need a few more parameters for this:
- layer_sizes is a list of hidden feature sizes of each layer in the model. In this example we use two GCN layers with 16-dimensional hidden node features at each layer.
- activations is a list of activations applied to each layer's output.
- dropout=0.3 specifies a 30% dropout at each layer.
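To make layer_sizes, activations and dropout concrete, here is a toy forward pass through a stack of GCN layers, each computing H' = relu(A_norm · dropout(H) · W). It is a plain-Python sketch with made-up matrices: applying dropout to each layer's input is our assumption about where the 30% dropout acts, and the hypothetical weights stand in for what Keras would learn.

```python
import random

def matmul(a, b):
    """Dense matrix product of nested lists a (n x k) and b (k x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def relu(m):
    return [[max(0.0, x) for x in row] for row in m]

def dropout(m, rate, rng):
    """Inverted dropout: zero each entry with probability `rate`, rescale the rest."""
    keep = 1.0 - rate
    return [[x / keep if rng.random() < keep else 0.0 for x in row] for row in m]

def gcn_forward(a_norm, features, weights, rate=0.0, seed=None):
    """Apply one GCN layer per weight matrix: H <- relu(A_norm @ dropout(H) @ W)."""
    rng = random.Random(seed)
    h = features
    for w in weights:
        if rate > 0:  # dropout is only active during training
            h = dropout(h, rate, rng)
        h = relu(matmul(matmul(a_norm, h), w))
    return h

# Made-up 3-node example: normalized adjacency, 2 input features per node,
# and two layers of size 2 (the notebook uses layer_sizes=[16, 16]).
a_norm = [[0.5, 0.5, 0.0], [0.5, 0.0, 0.5], [0.0, 0.5, 0.5]]
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights = [[[1.0, -1.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]

out = gcn_forward(a_norm, features, weights)                  # inference: no dropout
out_train = gcn_forward(a_norm, features, weights, rate=0.3, seed=0)
print(out)
```

Each layer mixes a node's features with its neighbours' through A_norm, so after two layers every embedding depends on the node's two-hop neighbourhood.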
We create a GCN model as follows:
[11]:
gcn = GCN(
    layer_sizes=[16, 16], activations=["relu", "relu"], generator=train_gen, dropout=0.3
)
To create a Keras model, we now expose the input and output tensors of the GCN model for link prediction via the GCN.in_out_tensors method:
[12]:
x_inp, x_out = gcn.in_out_tensors()
The final link classification layer takes the pair of node embeddings produced by the GCN model, applies a binary operator to them to produce the corresponding link embedding (ip for inner product; other options for the binary operator can be seen by running a cell with ?LinkEmbedding in it), and applies the activation function to the result:
[13]:
prediction = LinkEmbedding(activation="relu", method="ip")(x_out)
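The binary operator is what turns two node embeddings into one link embedding. The plain-Python helper below sketches several common choices; the method names mirror ones we believe LinkEmbedding offers, but treat them as illustrative and check ?LinkEmbedding for the exact options and semantics.

```python
def link_embedding(u, v, method="ip"):
    """Hypothetical helper: combine two node embeddings into a link embedding."""
    if method == "ip":      # inner product: a single similarity score
        return [sum(a * b for a, b in zip(u, v))]
    if method == "mul":     # element-wise (Hadamard) product
        return [a * b for a, b in zip(u, v)]
    if method == "l1":      # element-wise absolute difference
        return [abs(a - b) for a, b in zip(u, v)]
    if method == "l2":      # element-wise squared difference
        return [(a - b) ** 2 for a, b in zip(u, v)]
    if method == "avg":     # element-wise mean
        return [(a + b) / 2 for a, b in zip(u, v)]
    if method == "concat":  # concatenation: keeps order, doubles the dimension
        return list(u) + list(v)
    raise ValueError(f"unknown method: {method!r}")

citing, cited = [1.0, 2.0], [3.0, -1.0]
for m in ["ip", "mul", "l1", "l2", "avg", "concat"]:
    print(m, link_embedding(citing, cited, m))
```

Symmetric operators such as ip, mul, l1, l2 and avg give the same result for (citing, cited) and (cited, citing), so they ignore citation direction; concat preserves the order but produces a vector rather than a score, which is why it is usually followed by a dense layer. With ip, the output is already a single number per link, to which the notebook applies the chosen activation.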
The predictions need to be reshaped from (X, 1) to (X,) to match the shape of the targets we supplied above.
[14]:
prediction = keras.layers.Reshape((-1,))(prediction)
Stack the GCN and prediction layers into a Keras model, and specify the loss:
[15]:
model = keras.Model(inputs=x_inp, outputs=prediction)
model.compile(
    # `learning_rate` replaces the deprecated `lr` argument
    optimizer=keras.optimizers.Adam(learning_rate=0.01),
    loss=keras.losses.binary_crossentropy,
    metrics=["acc"],
)
Evaluate the initial (untrained) model on the train and test set:
[16]:
init_train_metrics = model.evaluate(train_flow)
init_test_metrics = model.evaluate(test_flow)
print("\nTrain Set Metrics of the initial (untrained) model:")
for name, val in zip(model.metrics_names, init_train_metrics):
    print("\t{}: {:0.4f}".format(name, val))
print("\nTest Set Metrics of the initial (untrained) model:")
for name, val in zip(model.metrics_names, init_test_metrics):
    print("\t{}: {:0.4f}".format(name, val))
['...']
1/1 [==============================] - 0s 121ms/step - loss: 1.8927 - acc: 0.5000
['...']
1/1 [==============================] - 0s 9ms/step - loss: 1.8621 - acc: 0.5000
Train Set Metrics of the initial (untrained) model:
loss: 1.8927
acc: 0.5000
Test Set Metrics of the initial (untrained) model:
loss: 1.8621
acc: 0.5000
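The initial accuracy of exactly 0.5 on both sets is consistent with how binary accuracy works: Keras counts a prediction as positive when it exceeds 0.5, and each split contains equally many positive and negative links, so a model that places every candidate on the same side of the threshold is right exactly half the time. A minimal plain-Python sketch of that metric (binary_accuracy here is our own helper, not the Keras function):

```python
def binary_accuracy(probs, labels, threshold=0.5):
    """Fraction of links whose thresholded prediction (p > threshold) matches the label."""
    hits = sum(int((p > threshold) == bool(y)) for p, y in zip(probs, labels))
    return hits / len(labels)

# Balanced labels, like the EdgeSplitter output: as many true links as false ones.
labels = [1, 1, 1, 0, 0, 0]

# Scores that all fall on one side of the threshold give exactly 0.5.
print(binary_accuracy([0.9] * 6, labels))  # 0.5: every link predicted to exist
print(binary_accuracy([0.1] * 6, labels))  # 0.5: every link predicted to be absent
```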
Train the model:
[17]:
history = model.fit(
    train_flow, epochs=epochs, validation_data=test_flow, verbose=2, shuffle=False
)
['...']
['...']
Train for 1 steps, validate for 1 steps
Epoch 1/50
1/1 - 1s - loss: 1.7886 - acc: 0.5000 - val_loss: 1.5024 - val_acc: 0.5387
Epoch 2/50
1/1 - 0s - loss: 1.7260 - acc: 0.5400 - val_loss: 0.6822 - val_acc: 0.6070
Epoch 3/50
1/1 - 0s - loss: 0.8526 - acc: 0.5953 - val_loss: 0.7401 - val_acc: 0.5729
Epoch 4/50
1/1 - 0s - loss: 0.7397 - acc: 0.5666 - val_loss: 0.7479 - val_acc: 0.5849
Epoch 5/50
1/1 - 0s - loss: 0.7334 - acc: 0.5799 - val_loss: 0.6777 - val_acc: 0.6172
Epoch 6/50
1/1 - 0s - loss: 0.6413 - acc: 0.6404 - val_loss: 0.6981 - val_acc: 0.6375
Epoch 7/50
1/1 - 0s - loss: 0.7289 - acc: 0.6568 - val_loss: 0.6576 - val_acc: 0.6448
Epoch 8/50
1/1 - 0s - loss: 0.6367 - acc: 0.6568 - val_loss: 0.6846 - val_acc: 0.6338
Epoch 9/50
1/1 - 0s - loss: 0.6111 - acc: 0.6639 - val_loss: 0.6850 - val_acc: 0.6384
Epoch 10/50
1/1 - 0s - loss: 0.5818 - acc: 0.6855 - val_loss: 0.6667 - val_acc: 0.6513
Epoch 11/50
1/1 - 0s - loss: 0.5721 - acc: 0.6916 - val_loss: 0.6304 - val_acc: 0.6688
Epoch 12/50
1/1 - 0s - loss: 0.5422 - acc: 0.7551 - val_loss: 0.6461 - val_acc: 0.7048
Epoch 13/50
1/1 - 0s - loss: 0.5791 - acc: 0.7695 - val_loss: 0.6710 - val_acc: 0.7002
Epoch 14/50
1/1 - 0s - loss: 0.4987 - acc: 0.7838 - val_loss: 0.6632 - val_acc: 0.7131
Epoch 15/50
1/1 - 0s - loss: 0.5537 - acc: 0.7920 - val_loss: 0.7022 - val_acc: 0.7168
Epoch 16/50
1/1 - 0s - loss: 0.5463 - acc: 0.7807 - val_loss: 0.7353 - val_acc: 0.7251
Epoch 17/50
1/1 - 0s - loss: 0.5315 - acc: 0.7910 - val_loss: 0.7022 - val_acc: 0.7223
Epoch 18/50
1/1 - 0s - loss: 0.4832 - acc: 0.7930 - val_loss: 0.6777 - val_acc: 0.7251
Epoch 19/50
1/1 - 0s - loss: 0.4477 - acc: 0.8105 - val_loss: 0.6668 - val_acc: 0.7242
Epoch 20/50
1/1 - 0s - loss: 0.4439 - acc: 0.7971 - val_loss: 0.6176 - val_acc: 0.7196
Epoch 21/50
1/1 - 0s - loss: 0.3993 - acc: 0.8309 - val_loss: 0.6136 - val_acc: 0.7196
Epoch 22/50
1/1 - 0s - loss: 0.3830 - acc: 0.8248 - val_loss: 0.6248 - val_acc: 0.7196
Epoch 23/50
1/1 - 0s - loss: 0.4062 - acc: 0.8473 - val_loss: 0.6505 - val_acc: 0.7205
Epoch 24/50
1/1 - 0s - loss: 0.4259 - acc: 0.8504 - val_loss: 0.6313 - val_acc: 0.7232
Epoch 25/50
1/1 - 0s - loss: 0.3858 - acc: 0.8504 - val_loss: 0.6221 - val_acc: 0.7232
Epoch 26/50
1/1 - 0s - loss: 0.3439 - acc: 0.8596 - val_loss: 0.6356 - val_acc: 0.7196
Epoch 27/50
1/1 - 0s - loss: 0.3333 - acc: 0.8709 - val_loss: 0.6512 - val_acc: 0.7205
Epoch 28/50
1/1 - 0s - loss: 0.3255 - acc: 0.8760 - val_loss: 0.6791 - val_acc: 0.7232
Epoch 29/50
1/1 - 0s - loss: 0.3593 - acc: 0.8791 - val_loss: 0.7117 - val_acc: 0.7214
Epoch 30/50
1/1 - 0s - loss: 0.3251 - acc: 0.8873 - val_loss: 0.7323 - val_acc: 0.7242
Epoch 31/50
1/1 - 0s - loss: 0.3256 - acc: 0.8770 - val_loss: 0.7427 - val_acc: 0.7288
Epoch 32/50
1/1 - 0s - loss: 0.3088 - acc: 0.9037 - val_loss: 0.7509 - val_acc: 0.7297
Epoch 33/50
1/1 - 0s - loss: 0.3048 - acc: 0.8934 - val_loss: 0.7523 - val_acc: 0.7371
Epoch 34/50
1/1 - 0s - loss: 0.2989 - acc: 0.8996 - val_loss: 0.7425 - val_acc: 0.7380
Epoch 35/50
1/1 - 0s - loss: 0.2847 - acc: 0.9047 - val_loss: 0.7396 - val_acc: 0.7362
Epoch 36/50
1/1 - 0s - loss: 0.2645 - acc: 0.9016 - val_loss: 0.7313 - val_acc: 0.7380
Epoch 37/50
1/1 - 0s - loss: 0.2811 - acc: 0.8975 - val_loss: 0.7350 - val_acc: 0.7362
Epoch 38/50
1/1 - 0s - loss: 0.2720 - acc: 0.9078 - val_loss: 0.6788 - val_acc: 0.7389
Epoch 39/50
1/1 - 0s - loss: 0.2603 - acc: 0.8986 - val_loss: 0.6679 - val_acc: 0.7371
Epoch 40/50
1/1 - 0s - loss: 0.2580 - acc: 0.9047 - val_loss: 0.6692 - val_acc: 0.7408
Epoch 41/50
1/1 - 0s - loss: 0.2809 - acc: 0.8955 - val_loss: 0.6916 - val_acc: 0.7408
Epoch 42/50
1/1 - 0s - loss: 0.2540 - acc: 0.9016 - val_loss: 0.7552 - val_acc: 0.7435
Epoch 43/50
1/1 - 0s - loss: 0.2629 - acc: 0.9139 - val_loss: 0.8007 - val_acc: 0.7445
Epoch 44/50
1/1 - 0s - loss: 0.2614 - acc: 0.9273 - val_loss: 0.8633 - val_acc: 0.7445
Epoch 45/50
1/1 - 0s - loss: 0.2316 - acc: 0.9057 - val_loss: 0.8980 - val_acc: 0.7500
Epoch 46/50
1/1 - 0s - loss: 0.2204 - acc: 0.9242 - val_loss: 0.9062 - val_acc: 0.7472
Epoch 47/50
1/1 - 0s - loss: 0.2326 - acc: 0.9160 - val_loss: 0.9067 - val_acc: 0.7537
Epoch 48/50
1/1 - 0s - loss: 0.2358 - acc: 0.9334 - val_loss: 0.8805 - val_acc: 0.7601
Epoch 49/50
1/1 - 0s - loss: 0.2196 - acc: 0.9211 - val_loss: 0.8471 - val_acc: 0.7592
Epoch 50/50
1/1 - 0s - loss: 0.2102 - acc: 0.9221 - val_loss: 0.8198 - val_acc: 0.7620
Plot the training history:
[18]:
sg.utils.plot_history(history)

Evaluate the trained model on the train and test citation links:
[19]:
train_metrics = model.evaluate(train_flow)
test_metrics = model.evaluate(test_flow)
print("\nTrain Set Metrics of the trained model:")
for name, val in zip(model.metrics_names, train_metrics):
    print("\t{}: {:0.4f}".format(name, val))
print("\nTest Set Metrics of the trained model:")
for name, val in zip(model.metrics_names, test_metrics):
    print("\t{}: {:0.4f}".format(name, val))
['...']
1/1 [==============================] - 0s 9ms/step - loss: 0.1409 - acc: 0.9641
['...']
1/1 [==============================] - 0s 9ms/step - loss: 0.8198 - acc: 0.7620
Train Set Metrics of the trained model:
loss: 0.1409
acc: 0.9641
Test Set Metrics of the trained model:
loss: 0.8198
acc: 0.7620