"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, we use our implementation of the [GraphSAGE](http://snap.stanford.edu/graphsage/) algorithm to build a model that predicts citation links in the PubMed-Diabetes dataset (see below). The problem is treated as a supervised link prediction problem on a homogeneous citation network with nodes representing papers (with attributes such as binary keyword indicators and categorical subject) and links corresponding to paper-paper citations. \n",
"\n",
"To address this problem, we build a model with the following architecture. First we build a two-layer GraphSAGE model that takes labeled `(paper1, paper2)` node pairs corresponding to possible citation links, and outputs a pair of node embeddings for the `paper1` and `paper2` nodes of the pair. These embeddings are then fed into a link classification layer, which first applies a binary operator to those node embeddings (e.g., concatenating them) to construct the embedding of the potential link. Thus obtained link embeddings are passed through the dense link classification layer to obtain link predictions - probability for these candidate links to actually exist in the network. The entire model is trained end-to-end by minimizing the loss function of choice (e.g., binary cross-entropy between predicted link probabilities and true link labels, with true/false citation links having labels 1/0) using stochastic gradient descent (SGD) updates of the model parameters, with minibatches of 'training' links fed into the model.\n",
"\n",
"Lastly, we investigate the nature of prediction probabilities. We want to know if GraphSAGE's prediction probabilities are well calibrated or not. In the latter case, we present two methods for calibrating the model's output.\n",
"\n",
"**References**\n",
"\n",
"1. Inductive Representation Learning on Large Graphs. W.L. Hamilton, R. Ying, and J. Leskovec arXiv:1706.02216 \n",
"[cs.SI], 2017. ([link](http://snap.stanford.edu/graphsage/))\n",
"\n",
"2. On Calibration of Modern Neural Networks. C. Guo, G. Pleiss, Y. Sun, and K. Q. Weinberger. \n",
"ICML 2017. ([link](https://geoffpleiss.com/nn_calibration))"
]
},
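{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training objective described above can be illustrated with a minimal NumPy sketch of the binary cross-entropy between predicted link probabilities and 0/1 link labels. The function name `binary_cross_entropy` and the clipping constant `eps` are illustrative assumptions; during training, this notebook lets Keras compute the loss via `keras.losses.binary_crossentropy`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def binary_cross_entropy(y_true, p_pred, eps=1e-7):\n",
"    # Clip probabilities away from 0 and 1 so that the logarithms stay finite\n",
"    p = np.clip(np.asarray(p_pred, dtype=float), eps, 1.0 - eps)\n",
"    y = np.asarray(y_true, dtype=float)\n",
"    # Mean over links of -[y * log(p) + (1 - y) * log(1 - p)]\n",
"    return float(np.mean(-(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))))\n",
"\n",
"\n",
"# True citation links are labeled 1, sampled non-links are labeled 0\n",
"labels = np.array([1, 0, 1, 0])\n",
"uninformative = np.full(4, 0.5)  # a model that always predicts 0.5\n",
"confident = np.array([0.9, 0.1, 0.8, 0.2])\n",
"print(binary_cross_entropy(labels, uninformative))  # log(2), about 0.693\n",
"print(binary_cross_entropy(labels, confident))  # about 0.164\n",
"```\n",
"\n",
"An uninformative model scores log(2), so losses well below that value indicate that the model is separating true from false links."
]
},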
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"nbsphinx": "hidden",
"tags": [
"CloudRunner"
]
},
"outputs": [],
"source": [
"# install StellarGraph if running on Google Colab\n",
"import sys\n",
"if 'google.colab' in sys.modules:\n",
" %pip install -q stellargraph[demos]==1.0.0"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"nbsphinx": "hidden",
"tags": [
"VersionCheck"
]
},
"outputs": [],
"source": [
"# verify that we're using the correct version of StellarGraph for this notebook\n",
"import stellargraph as sg\n",
"\n",
"try:\n",
" sg.utils.validate_notebook_version(\"1.0.0\")\n",
"except AttributeError:\n",
" raise ValueError(\n",
" f\"This notebook requires StellarGraph version 1.0.0, but a different version {sg.__version__} is installed. Please see .\"\n",
" ) from None"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import networkx as nx\n",
"import pandas as pd\n",
"import numpy as np\n",
"import itertools\n",
"import os\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import stellargraph as sg\n",
"from stellargraph.data import EdgeSplitter\n",
"from stellargraph.mapper import GraphSAGELinkGenerator\n",
"from stellargraph.layer import GraphSAGE, link_classification\n",
"from stellargraph.calibration import expected_calibration_error, plot_reliability_diagram\n",
"from stellargraph.calibration import IsotonicCalibration, TemperatureCalibration\n",
"\n",
"from tensorflow import keras\n",
"from sklearn import preprocessing, feature_extraction, model_selection\n",
"from sklearn.calibration import calibration_curve\n",
"from sklearn.isotonic import IsotonicRegression\n",
"\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"from stellargraph import globalvar\n",
"from stellargraph import datasets\n",
"from IPython.display import display, HTML\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Global parameters\n",
"Specify the minibatch size (number of node pairs per minibatch) and the number of epochs for training the model:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": [
"parameters"
]
},
"outputs": [],
"source": [
"batch_size = 50\n",
"epochs = 20 # The number of training epochs for training the GraphSAGE model.\n",
"\n",
"# train, test, validation split\n",
"train_size = 0.2\n",
"test_size = 0.15\n",
"val_size = 0.2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading the PubMed Diabetes network data"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": [
"DataLoadingLinks"
]
},
"source": [
"(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"tags": [
"DataLoading"
]
},
"outputs": [
{
"data": {
"text/html": [
"The PubMed Diabetes dataset consists of 19717 scientific publications from PubMed database pertaining to diabetes classified into one of three classes. The citation network consists of 44338 links. Each publication in the dataset is described by a TF/IDF weighted word vector from a dictionary which consists of 500 unique words."
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dataset = datasets.PubMedDiabetes()\n",
"display(HTML(dataset.description))\n",
"G, _subjects = dataset.load()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"StellarGraph: Undirected multigraph\n",
" Nodes: 19717, Edges: 44338\n",
"\n",
" Node types:\n",
" paper: [19717]\n",
" Features: float32 vector, length 500\n",
" Edge types: paper-cites->paper\n",
"\n",
" Edge types:\n",
" paper-cites->paper: [44338]\n"
]
}
],
"source": [
"print(G.info())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We aim to train a link prediction model, hence we need to prepare the train and test sets of links and the corresponding graphs with those links removed.\n",
"\n",
"We are going to split our input graph into a train and test graphs using the EdgeSplitter class in `stellargraph.data`. We will use the train graph for training the model (a binary classifier that, given two nodes, predicts whether a link between these two nodes should exist or not) and the test graph for evaluating the model's performance on hold out data.\n",
"Each of these graphs will have the same number of nodes as the input graph, but the number of links will differ (be reduced) as some of the links will be removed during each split and used as the positive samples for training/testing the link prediction classifier."
]
},
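{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each split can be sketched in plain Python on a toy edge list: sample a fraction `p` of the positive edges, remove them from the graph, and sample the same number of unlinked node pairs as negative examples. This is a simplified illustration of the idea behind `EdgeSplitter.train_test_split` with `method=\"global\"`, not its implementation: the helper `split_edges` and the toy graph are hypothetical, and the sketch ignores `keep_connected=True`, which (as we understand it) restricts which edges may be removed so that the reduced graph keeps the connectivity of the input graph.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"\n",
"def split_edges(nodes, edges, p, seed=0):\n",
"    # Hypothetical helper: hold out a fraction p of the positive edges and\n",
"    # sample the same number of negative (unlinked) node pairs\n",
"    rng = random.Random(seed)\n",
"    n_pos = int(p * len(edges))\n",
"    positives = rng.sample(edges, n_pos)\n",
"    held_out = set(positives)\n",
"    reduced_edges = [e for e in edges if e not in held_out]\n",
"\n",
"    # Negative examples: distinct node pairs that are not linked in the input graph\n",
"    existing = {frozenset(e) for e in edges}\n",
"    negatives = []\n",
"    while len(negatives) < n_pos:\n",
"        u, v = rng.sample(nodes, 2)\n",
"        pair = frozenset((u, v))\n",
"        if pair not in existing and pair not in {frozenset(n) for n in negatives}:\n",
"            negatives.append((u, v))\n",
"\n",
"    edge_ids = positives + negatives\n",
"    edge_labels = [1] * len(positives) + [0] * len(negatives)\n",
"    return reduced_edges, edge_ids, edge_labels\n",
"\n",
"\n",
"nodes = list(range(8))\n",
"edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (0, 7), (1, 5), (2, 6)]\n",
"reduced, toy_ids, toy_labels = split_edges(nodes, edges, p=0.2)\n",
"print(len(reduced), toy_ids, toy_labels)  # 8 remaining edges, 2 positive + 2 negative examples\n",
"```"
]
},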
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the original graph G, extract a randomly sampled subset of validation edges (true and false citation links) and the reduced graph G_test with the positive test edges removed:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Removed 1000 edges\n",
"Removed 2000 edges\n",
"Removed 3000 edges\n",
"Removed 4000 edges\n",
"Removed 5000 edges\n",
"Removed 6000 edges\n",
"Sampled 1000 negative examples\n",
"Sampled 2000 negative examples\n",
"Sampled 3000 negative examples\n",
"Sampled 4000 negative examples\n",
"Sampled 5000 negative examples\n",
"Sampled 6000 negative examples\n",
"** Sampled 6650 positive and 6650 negative edges. **\n"
]
}
],
"source": [
"# Define an edge splitter on the original graph G:\n",
"edge_splitter_test = EdgeSplitter(G)\n",
"\n",
"# Randomly sample a fraction p=0.1 of all positive links, and same number of negative links, from G, and obtain the\n",
"# reduced graph G_test with the sampled links removed:\n",
"G_test, edge_ids_test, edge_labels_test = edge_splitter_test.train_test_split(\n",
" p=test_size, method=\"global\", keep_connected=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reduced graph G_test, together with the test ground truth set of links (edge_ids_test, edge_labels_test), will be used for testing the model.\n",
"\n",
"Now repeat this procedure to obtain the validation data for the model. From the reduced graph G_test, extract a randomly sampled subset of validation edges (true and false citation links) and the reduced graph G_val with the positive train edges removed:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Removed 1000 edges\n",
"Removed 2000 edges\n",
"Removed 3000 edges\n",
"Removed 4000 edges\n",
"Removed 5000 edges\n",
"Removed 6000 edges\n",
"Removed 7000 edges\n",
"Sampled 1000 negative examples\n",
"Sampled 2000 negative examples\n",
"Sampled 3000 negative examples\n",
"Sampled 4000 negative examples\n",
"Sampled 5000 negative examples\n",
"Sampled 6000 negative examples\n",
"Sampled 7000 negative examples\n",
"** Sampled 7537 positive and 7537 negative edges. **\n"
]
}
],
"source": [
"# Define an edge splitter on the reduced graph G_test:\n",
"edge_splitter_val = EdgeSplitter(G_test)\n",
"\n",
"# Randomly sample a fraction p=0.1 of all positive links, and same number of negative links, from G_test, and obtain the\n",
"# reduced graph G_train with the sampled links removed:\n",
"G_val, edge_ids_val, edge_labels_val = edge_splitter_val.train_test_split(\n",
" p=val_size, method=\"global\", keep_connected=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reduced graph G_val, together with the validation ground truth set of links (edge_ids_val, edge_labels_val), will be used for validating the model (can also be used to tune the model parameters).\n",
"\n",
"Now repeat this procedure to obtain the training data for the model. From the reduced graph G_val, extract a randomly sampled subset of train edges (true and false citation links) and the reduced graph G_train with the positive train edges removed:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Removed 1000 edges\n",
"Removed 2000 edges\n",
"Removed 3000 edges\n",
"Removed 4000 edges\n",
"Removed 5000 edges\n",
"Removed 6000 edges\n",
"Sampled 1000 negative examples\n",
"Sampled 2000 negative examples\n",
"Sampled 3000 negative examples\n",
"Sampled 4000 negative examples\n",
"Sampled 5000 negative examples\n",
"Sampled 6000 negative examples\n",
"** Sampled 6030 positive and 6030 negative edges. **\n"
]
}
],
"source": [
"# Define an edge splitter on the reduced graph G_test:\n",
"edge_splitter_train = EdgeSplitter(G_val)\n",
"\n",
"# Randomly sample a fraction p=0.1 of all positive links, and same number of negative links, from G_test, and obtain the\n",
"# reduced graph G_train with the sampled links removed:\n",
"G_train, edge_ids_train, edge_labels_train = edge_splitter_train.train_test_split(\n",
" p=train_size, method=\"global\", keep_connected=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"G_train, together with the train ground truth set of links (edge_ids_train, edge_labels_train), will be used for training the model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Summary of G_train, G_val and G_test - note that they have the same set of nodes, only differing in their edge sets:"
]
},
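{
"cell_type": "markdown",
"metadata": {},
"source": [
"The edge counts in the summaries below follow from applying each sampling fraction to the edges that remain after the previous split. A quick arithmetic check, assuming that the number of sampled positive edges is `int(p * num_edges)` (this reproduces the counts logged by the edge splitters above, but we have not confirmed the exact rounding rule used by `EdgeSplitter`):\n",
"\n",
"```python\n",
"num_edges = 44338  # edges in G, from G.info()\n",
"fractions = [('test', 0.15), ('val', 0.2), ('train', 0.2)]  # test_size, val_size, train_size\n",
"\n",
"counts = []\n",
"remaining = num_edges\n",
"for name, p in fractions:\n",
"    n_pos = int(p * remaining)  # positive edges held out (plus as many negative pairs)\n",
"    remaining -= n_pos\n",
"    counts.append((name, n_pos, remaining))\n",
"    print(f'{name}: {n_pos} positive edges sampled, {remaining} edges left')\n",
"\n",
"# Expected output:\n",
"# test: 6650 positive edges sampled, 37688 edges left\n",
"# val: 7537 positive edges sampled, 30151 edges left\n",
"# train: 6030 positive edges sampled, 24121 edges left\n",
"```"
]
},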
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"StellarGraph: Undirected multigraph\n",
" Nodes: 19717, Edges: 24121\n",
"\n",
" Node types:\n",
" paper: [19717]\n",
" Features: float32 vector, length 500\n",
" Edge types: paper-cites->paper\n",
"\n",
" Edge types:\n",
" paper-cites->paper: [24121]\n"
]
}
],
"source": [
"print(G_train.info())"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"StellarGraph: Undirected multigraph\n",
" Nodes: 19717, Edges: 30151\n",
"\n",
" Node types:\n",
" paper: [19717]\n",
" Features: float32 vector, length 500\n",
" Edge types: paper-cites->paper\n",
"\n",
" Edge types:\n",
" paper-cites->paper: [30151]\n"
]
}
],
"source": [
"print(G_val.info())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"StellarGraph: Undirected multigraph\n",
" Nodes: 19717, Edges: 37688\n",
"\n",
" Node types:\n",
" paper: [19717]\n",
" Features: float32 vector, length 500\n",
" Edge types: paper-cites->paper\n",
"\n",
" Edge types:\n",
" paper-cites->paper: [37688]\n"
]
}
],
"source": [
"print(G_test.info())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we create the link generators for sampling and streaming train and test link examples to the model. The link generators essentially \"map\" pairs of nodes `(paper1, paper2)` to the input of GraphSAGE: they take minibatches of node pairs, sample 2-hop subgraphs with `(paper1, paper2)` head nodes extracted from those pairs, and feed them, together with the corresponding binary labels indicating whether those pairs represent true or false citation links, to the input layer of the GraphSAGE model, for SGD updates of the model parameters."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"Specify the sizes of 1- and 2-hop neighbour samples for GraphSAGE:\n",
"\n",
"Note that the length of `num_samples` list defines the number of layers/iterations in the GraphSAGE model. In this example, we are defining a 2-layer GraphSAGE model."
]
},
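{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make `num_samples = [10, 5]` concrete, here is a toy sketch of fixed-size neighbour sampling for a single head node, in the spirit of the GraphSAGE paper: 10 neighbours are drawn at the first hop and 5 per sampled node at the second hop. Sampling uniformly with replacement is an assumption of this sketch and may differ in detail from `GraphSAGELinkGenerator`; the adjacency dict and `sample_neighbourhood` are hypothetical.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"\n",
"def sample_neighbourhood(adj, head, num_samples, seed=0):\n",
"    # Uniformly sample a fixed number of neighbours per node and hop, with replacement.\n",
"    # Returns one list of node ids per hop: [head], hop-1 samples, hop-2 samples, ...\n",
"    rng = random.Random(seed)\n",
"    hops = [[head]]\n",
"    for k in num_samples:\n",
"        next_hop = []\n",
"        for node in hops[-1]:\n",
"            next_hop.extend(rng.choice(adj[node]) for _ in range(k))\n",
"        hops.append(next_hop)\n",
"    return hops\n",
"\n",
"\n",
"# Hypothetical toy adjacency list in which every node has at least one neighbour\n",
"adj = {0: [1, 2], 1: [0, 2, 3], 2: [0, 1], 3: [1]}\n",
"hops = sample_neighbourhood(adj, head=0, num_samples=[10, 5])\n",
"print([len(hop) for hop in hops])  # [1, 10, 50]\n",
"\n",
"# A (paper1, paper2) pair has two head nodes, each with 1 + 10 + 10 * 5 sampled nodes\n",
"print(2 * sum(len(hop) for hop in hops))  # 122\n",
"```"
]
},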
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"num_samples = [10, 5]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"train_gen = GraphSAGELinkGenerator(G_train, batch_size, num_samples)\n",
"val_gen = GraphSAGELinkGenerator(G_val, batch_size, num_samples)\n",
"test_gen = GraphSAGELinkGenerator(G_test, batch_size, num_samples)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"GraphSAGE part of the model, with hidden layer sizes of 50 for both GraphSAGE layers, a bias term, and dropout.\n",
"\n",
"Note that the length of layer_sizes list must be equal to the length of num_samples, as len(num_samples) defines the number of hops (layers) in the GraphSAGE model."
]
},
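{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each GraphSAGE layer combines a node's own representation with an aggregate of its sampled neighbours. Below is a minimal NumPy sketch of one layer following Algorithm 1 of the GraphSAGE paper (reference 1), using the element-wise mean of the neighbours as the aggregator, a dense layer over the concatenated self and neighbour vectors, a ReLU, and L2 normalisation. The shapes and random weights are illustrative assumptions; this is not StellarGraph's aggregator code.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def graphsage_mean_layer(h_self, h_neigh, W, b):\n",
"    # h_self: (batch, d_in) representations of the target nodes\n",
"    # h_neigh: (batch, num_neighbours, d_in) representations of their sampled neighbours\n",
"    h_agg = h_neigh.mean(axis=1)  # mean aggregator over the sampled neighbours\n",
"    h_cat = np.concatenate([h_self, h_agg], axis=1)  # (batch, 2 * d_in)\n",
"    h_out = np.maximum(h_cat @ W + b, 0.0)  # dense layer followed by ReLU\n",
"    # L2-normalise each embedding; the floor avoids division by zero for all-zero rows\n",
"    norms = np.linalg.norm(h_out, axis=1, keepdims=True)\n",
"    return h_out / np.maximum(norms, 1e-12)\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"d_in, d_out, batch, k = 500, 32, 4, 10  # 500 input features, 32 hidden units, 10 neighbours\n",
"W = rng.normal(scale=0.05, size=(2 * d_in, d_out))\n",
"b = np.zeros(d_out)\n",
"h = graphsage_mean_layer(rng.random((batch, d_in)), rng.random((batch, k, d_in)), W, b)\n",
"print(h.shape)  # (4, 32)\n",
"```\n",
"\n",
"In a two-layer model the same step is applied twice: first to turn the hop-2 samples into updated hop-1 representations, then to turn those into the head node's final embedding."
]
},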
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"layer_sizes = [32, 32]\n",
"graphsage = GraphSAGE(\n",
" layer_sizes=layer_sizes, generator=train_gen, bias=True, dropout=0.2\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Build the model and expose input and output sockets of graphsage, for node pair inputs:\n",
"x_inp, x_out = graphsage.in_out_tensors()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Final link classification layer that takes a pair of node embeddings produced by graphsage, applies a binary operator to them to produce the corresponding link embedding ('ip' for inner product; other options for the binary operator can be seen by running a cell with `?link_classification` in it), and passes it through a dense layer:"
]
},
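{
"cell_type": "markdown",
"metadata": {},
"source": [
"The operator step can be sketched in NumPy: the inner product ('ip') reduces two node embeddings to a single score per link, while operators such as the element-wise product or concatenation produce vector link embeddings. The keys in the `operators` dict are illustrative labels, not necessarily the strings accepted by `edge_embedding_method`, and the sketch treats the inner product directly as the logit, omitting any dense-layer weights.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"\n",
"# Illustrative binary operators that combine two node embeddings into a link embedding\n",
"operators = {\n",
"    'ip': lambda u, v: np.sum(u * v, axis=-1, keepdims=True),  # inner product: one score per link\n",
"    'hadamard': lambda u, v: u * v,  # element-wise product\n",
"    'concat': lambda u, v: np.concatenate([u, v], axis=-1),  # concatenation\n",
"}\n",
"\n",
"# Toy embeddings for two candidate links (one row per (paper1, paper2) pair)\n",
"paper1_emb = np.array([[0.5, 0.5, 0.0], [1.0, 0.0, 0.0]])\n",
"paper2_emb = np.array([[0.5, 0.5, 0.0], [0.0, 1.0, 0.0]])\n",
"\n",
"for name, op in operators.items():\n",
"    print(name, op(paper1_emb, paper2_emb).shape)  # (2, 1), (2, 3), (2, 6)\n",
"\n",
"logits = operators['ip'](paper1_emb, paper2_emb)  # [[0.5], [0.0]]\n",
"probabilities = sigmoid(logits)  # approximately [[0.622], [0.5]]\n",
"print(probabilities)\n",
"```"
]
},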
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"link_classification: using 'ip' method to combine node embeddings into edge embeddings\n"
]
}
],
"source": [
"logits = link_classification(\n",
" output_dim=1, output_act=\"linear\", edge_embedding_method=\"ip\"\n",
")(x_out)\n",
"\n",
"prediction = keras.layers.Activation(keras.activations.sigmoid)(logits)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Stack the GraphSAGE and prediction layers into a Keras model, and specify the loss"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"model = keras.Model(inputs=x_inp, outputs=prediction)\n",
"\n",
"model.compile(\n",
" optimizer=keras.optimizers.Adam(lr=1e-3),\n",
" loss=keras.losses.binary_crossentropy,\n",
" metrics=[keras.metrics.binary_accuracy],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Evaluate the initial (untrained) model on the train, val and test sets:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"train_flow = train_gen.flow(edge_ids_train, edge_labels_train, shuffle=True)\n",
"val_flow = val_gen.flow(edge_ids_val, edge_labels_val)\n",
"test_flow = test_gen.flow(edge_ids_test, edge_labels_test)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ['...']\n",
"242/242 [==============================] - 12s 50ms/step - loss: 0.6779 - binary_accuracy: 0.5118\n",
" ['...']\n",
"302/302 [==============================] - 17s 57ms/step - loss: 0.6753 - binary_accuracy: 0.5125 2s -\n",
" ['...']\n",
"266/266 [==============================] - 17s 64ms/step - loss: 0.6748 - binary_accuracy: 0.5102\n",
"\n",
"Train Set Metrics of the initial (untrained) model:\n",
"\tloss: 0.6779\n",
"\tbinary_accuracy: 0.5118\n",
"\n",
"Validation Set Metrics of the initial (untrained) model:\n",
"\tloss: 0.6753\n",
"\tbinary_accuracy: 0.5125\n",
"\n",
"Test Set Metrics of the initial (untrained) model:\n",
"\tloss: 0.6748\n",
"\tbinary_accuracy: 0.5102\n"
]
}
],
"source": [
"init_train_metrics = model.evaluate(train_flow)\n",
"init_val_metrics = model.evaluate(val_flow)\n",
"init_test_metrics = model.evaluate(test_flow)\n",
"\n",
"print(\"\\nTrain Set Metrics of the initial (untrained) model:\")\n",
"for name, val in zip(model.metrics_names, init_train_metrics):\n",
" print(\"\\t{}: {:0.4f}\".format(name, val))\n",
"\n",
"print(\"\\nValidation Set Metrics of the initial (untrained) model:\")\n",
"for name, val in zip(model.metrics_names, init_val_metrics):\n",
" print(\"\\t{}: {:0.4f}\".format(name, val))\n",
"\n",
"print(\"\\nTest Set Metrics of the initial (untrained) model:\")\n",
"for name, val in zip(model.metrics_names, init_test_metrics):\n",
" print(\"\\t{}: {:0.4f}\".format(name, val))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the model:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ['...']\n",
" ['...']\n"
]
}
],
"source": [
"history = model.fit(\n",
" train_flow, epochs=epochs, validation_data=val_flow, verbose=0, shuffle=True,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plot the training history:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"
"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"sg.utils.plot_history(history)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Evaluate the trained model on test citation links:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ['...']\n",
"242/242 [==============================] - 13s 53ms/step - loss: 0.4606 - binary_accuracy: 0.8850\n",
" ['...']\n",
"302/302 [==============================] - 17s 58ms/step - loss: 0.5606 - binary_accuracy: 0.7306\n",
" ['...']\n",
"266/266 [==============================] - 17s 66ms/step - loss: 0.5574 - binary_accuracy: 0.7349\n",
"\n",
"Train Set Metrics of the trained model:\n",
"\tloss: 0.4606\n",
"\tbinary_accuracy: 0.8850\n",
"\n",
"Validation Set Metrics of the trained model:\n",
"\tloss: 0.5606\n",
"\tbinary_accuracy: 0.7306\n",
"\n",
"Test Set Metrics of the trained model:\n",
"\tloss: 0.5574\n",
"\tbinary_accuracy: 0.7349\n"
]
}
],
"source": [
"train_metrics = model.evaluate(train_flow)\n",
"val_metrics = model.evaluate(val_flow)\n",
"test_metrics = model.evaluate(test_flow)\n",
"\n",
"print(\"\\nTrain Set Metrics of the trained model:\")\n",
"for name, val in zip(model.metrics_names, train_metrics):\n",
" print(\"\\t{}: {:0.4f}\".format(name, val))\n",
"\n",
"print(\"\\nValidation Set Metrics of the trained model:\")\n",
"for name, val in zip(model.metrics_names, val_metrics):\n",
" print(\"\\t{}: {:0.4f}\".format(name, val))\n",
"\n",
"print(\"\\nTest Set Metrics of the trained model:\")\n",
"for name, val in zip(model.metrics_names, test_metrics):\n",
" print(\"\\t{}: {:0.4f}\".format(name, val))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"num_tests = 1 # the number of times to generate predictions"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"266/266 [==============================] - 18s 67ms/step\n"
]
}
],
"source": [
"all_test_predictions = [\n",
" model.predict(test_flow, verbose=True) for _ in np.arange(num_tests)\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Diagnosing model miscalibration\n",
"\n",
"We are going to use method from scikit-learn.calibration module to calibrate the binary classifier."
]
},
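{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the output of `calibration_curve` concrete, here is a minimal NumPy re-implementation of equal-width binning on made-up predictions: for each non-empty bin it returns the fraction of positive labels and the mean predicted probability, in the same order as scikit-learn (`prob_true`, then `prob_pred`). This is a sketch for intuition; the helper name `reliability_bins` is hypothetical, and scikit-learn's handling of values that fall exactly on a bin edge may differ.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def reliability_bins(y_true, y_prob, n_bins=10):\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_prob = np.asarray(y_prob, dtype=float)\n",
"    # Assign each prediction to one of n_bins equal-width bins on [0, 1]\n",
"    bin_ids = np.minimum((y_prob * n_bins).astype(int), n_bins - 1)\n",
"    fraction_of_positives, mean_predicted_value = [], []\n",
"    for b in range(n_bins):\n",
"        in_bin = bin_ids == b\n",
"        if in_bin.any():  # empty bins are skipped\n",
"            fraction_of_positives.append(y_true[in_bin].mean())\n",
"            mean_predicted_value.append(y_prob[in_bin].mean())\n",
"    return np.array(fraction_of_positives), np.array(mean_predicted_value)\n",
"\n",
"\n",
"y_true = [0, 0, 1, 1, 1, 0]\n",
"y_prob = [0.05, 0.15, 0.12, 0.85, 0.95, 0.9]\n",
"frac_pos, mean_pred = reliability_bins(y_true, y_prob)\n",
"print(frac_pos)  # fraction of true links in bins 0, 1, 8 and 9\n",
"print(mean_pred)  # mean predicted probability in the same bins\n",
"```"
]
},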
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"calibration_data = [\n",
" calibration_curve(\n",
" y_prob=test_predictions, y_true=edge_labels_test, n_bins=10, normalize=True\n",
" )\n",
" for test_predictions in all_test_predictions\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let' calculate the expected calibration error on the test set before calibration."
]
},
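{
"cell_type": "markdown",
"metadata": {},
"source": [
"The expected calibration error of Guo et al. (reference 2) weights the gap between accuracy and confidence in each bin by the fraction of predictions in that bin, $\\mathrm{ECE} = \\sum_b \\frac{n_b}{N} |\\mathrm{acc}_b - \\mathrm{conf}_b|$. Following the notebook's usage, the sketch below takes the per-bin fraction of positives as the accuracy and the per-bin mean predicted probability as the confidence. The equal-width binning and the helper name `ece` are assumptions; StellarGraph's `expected_calibration_error` may differ in detail.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def ece(y_true, y_prob, n_bins=10):\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_prob = np.asarray(y_prob, dtype=float)\n",
"    # Equal-width bins on [0, 1]; a probability of exactly 1.0 goes into the last bin\n",
"    bin_ids = np.minimum((y_prob * n_bins).astype(int), n_bins - 1)\n",
"    total = 0.0\n",
"    for b in range(n_bins):\n",
"        in_bin = bin_ids == b\n",
"        if in_bin.any():\n",
"            acc = y_true[in_bin].mean()  # fraction of positives in the bin\n",
"            conf = y_prob[in_bin].mean()  # mean predicted probability in the bin\n",
"            total += in_bin.mean() * abs(acc - conf)  # weight n_b / N\n",
"    return total\n",
"\n",
"\n",
"# Well calibrated toy case: a probability of 0.25 where 1 in 4 links is real\n",
"print(ece([1, 0, 0, 0], [0.25] * 4))  # 0.0\n",
"# Over-confident toy case: a probability of 0.95 where half of the links are real\n",
"print(ece([1, 0, 1, 0], [0.95] * 4))  # about 0.45\n",
"```"
]
},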
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ECE: (before calibration) 0.3583\n"
]
}
],
"source": [
"for fraction_of_positives, mean_predicted_value in calibration_data:\n",
" ece_pre_calibration = expected_calibration_error(\n",
" prediction_probabilities=all_test_predictions[0],\n",
" accuracy=fraction_of_positives,\n",
" confidence=mean_predicted_value,\n",
" )\n",
" print(\"ECE: (before calibration) {:.4f}\".format(ece_pre_calibration))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's plot the reliability diagram. This is a visual aid for the diagnosis of a poorly calibrated binary classifier."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"
"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_reliability_diagram(\n",
" calibration_data, np.array(all_test_predictions[0]), ece=[ece_pre_calibration]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Calibration \n",
"\n",
"Next, we are going to use our validation set to calibrate the model.\n",
"\n",
"We will consider two different approaches for calibrating a binary classifier, Platt scaling and Isotonic regression.\n",
"\n",
"## Platt Scaling\n",
"\n",
"$q_i = \\sigma(\\alpha z_i+\\beta)$ where $z_i$ is the GraphSAGE output (before the last layer's activation function is applied), $q_i$ is the calibrated probability, and $\\sigma()$ is the sigmoid function. \n",
"\n",
"$\\alpha$ and $\\beta$ are the model's trainable parameters.\n",
"\n",
"For more information see:\n",
"- https://en.wikipedia.org/wiki/Platt_scaling\n",
"\n",
"## Isotonic Regression\n",
"\n",
"Isotonic Regression is a regression technique that fits a piece-wise, non-decreasing, linear function to data. For more information see:\n",
"- https://scikit-learn.org/stable/modules/generated/sklearn.isotonic.IsotonicRegression.html#sklearn.isotonic.IsotonicRegression\n",
"- https://en.wikipedia.org/wiki/Isotonic_regression"
]
},
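{
"cell_type": "markdown",
"metadata": {},
"source": [
"Platt scaling amounts to a one-feature logistic regression on the logits. The minimal NumPy sketch below fits $\\alpha$ and $\\beta$ by gradient descent on the binary cross-entropy, using synthetic logits whose true link probability is $\\sigma(0.5 z - 1)$. The learning rate, number of iterations and data are assumptions for illustration; `TemperatureCalibration` may use a different optimiser and parameterisation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"\n",
"def fit_platt(z, y, lr=0.1, n_iter=5000):\n",
"    # Fit q = sigmoid(alpha * z + beta) to 0/1 labels y by gradient descent on the mean BCE\n",
"    alpha, beta = 1.0, 0.0  # start from the identity map (uncalibrated model)\n",
"    for _ in range(n_iter):\n",
"        q = sigmoid(alpha * z + beta)\n",
"        residual = q - y  # derivative of the BCE with respect to alpha * z + beta\n",
"        alpha -= lr * np.mean(residual * z)\n",
"        beta -= lr * np.mean(residual)\n",
"    return alpha, beta\n",
"\n",
"\n",
"# Synthetic validation logits whose true link probability is sigmoid(0.5 * z - 1)\n",
"rng = np.random.default_rng(0)\n",
"z_val = rng.normal(scale=3.0, size=20000)\n",
"y_val = (rng.random(z_val.shape) < sigmoid(0.5 * z_val - 1.0)).astype(float)\n",
"alpha, beta = fit_platt(z_val, y_val)\n",
"print(alpha, beta)  # should be close to 0.5 and -1.0\n",
"\n",
"# Calibrated probabilities for new logits\n",
"z_test = np.array([-2.0, 0.0, 2.0])\n",
"print(sigmoid(alpha * z_test + beta))\n",
"```"
]
},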
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Select the calibration method."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"use_platt = False # True for Platt scaling or False for Isotonic Regression"
]
},
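{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `use_platt = False`, as here, isotonic regression is used. Its standard fitting procedure, the pool-adjacent-violators (PAV) algorithm, can be sketched in plain Python: sort the validation predictions by score, then merge neighbouring blocks of labels until the block means are non-decreasing. The helper names are hypothetical; scikit-learn's `IsotonicRegression` additionally handles ties, sample weights and interpolation between fitted points.\n",
"\n",
"```python\n",
"def pool_adjacent_violators(y):\n",
"    # Each block stores [sum of labels, count]; merge while the block means decrease\n",
"    blocks = []\n",
"    for value in y:\n",
"        blocks.append([float(value), 1])\n",
"        while len(blocks) > 1 and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]:\n",
"            total, count = blocks.pop()\n",
"            blocks[-1][0] += total\n",
"            blocks[-1][1] += count\n",
"    fitted = []\n",
"    for total, count in blocks:\n",
"        fitted.extend([total / count] * count)\n",
"    return fitted\n",
"\n",
"\n",
"def fit_isotonic(scores, labels):\n",
"    # Sort by uncalibrated score, then fit a non-decreasing step function to the labels\n",
"    order = sorted(range(len(scores)), key=lambda i: scores[i])\n",
"    sorted_scores = [scores[i] for i in order]\n",
"    fitted = pool_adjacent_violators([labels[i] for i in order])\n",
"    return sorted_scores, fitted\n",
"\n",
"\n",
"scores = [0.9, 0.1, 0.3, 0.2, 0.6, 0.7]\n",
"labels = [1, 0, 1, 0, 0, 1]\n",
"xs, fitted = fit_isotonic(scores, labels)\n",
"print(xs)  # [0.1, 0.2, 0.3, 0.6, 0.7, 0.9]\n",
"print(fitted)  # [0.0, 0.0, 0.5, 0.5, 1.0, 1.0]\n",
"```"
]
},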
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For simplicity, we are going to calibrate using a single prediction per query point."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"num_tests = 1"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"score_model = keras.Model(inputs=x_inp, outputs=logits)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"302/302 [==============================] - 18s 59ms/step\n",
"266/266 [==============================] - 15s 55ms/step\n"
]
}
],
"source": [
"if use_platt:\n",
" all_val_score_predictions = [\n",
" score_model.predict(val_flow, verbose=True) for _ in np.arange(num_tests)\n",
" ]\n",
" all_test_score_predictions = [\n",
" score_model.predict(test_flow, verbose=True) for _ in np.arange(num_tests)\n",
" ]\n",
" all_test_probabilistic_predictions = [\n",
" model.predict(test_flow, verbose=True) for _ in np.arange(num_tests)\n",
" ]\n",
"else:\n",
" all_val_score_predictions = [\n",
" model.predict(val_flow, verbose=True) for _ in np.arange(num_tests)\n",
" ]\n",
" all_test_probabilistic_predictions = [\n",
" model.predict(test_flow, verbose=True) for _ in np.arange(num_tests)\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(15074, 1)"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"val_predictions = np.mean(np.array(all_val_score_predictions), axis=0)\n",
"val_predictions.shape"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"# These are the uncalibrated prediction probabilities.\n",
"if use_platt:\n",
" test_predictions = np.mean(np.array(all_test_score_predictions), axis=0)\n",
" test_predictions.shape\n",
"else:\n",
" test_predictions = np.mean(np.array(all_test_probabilistic_predictions), axis=0)\n",
" test_predictions.shape"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"if use_platt:\n",
" # for binary classification this class performs Platt Scaling\n",
" lr = TemperatureCalibration()\n",
"else:\n",
" lr = IsotonicCalibration()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((15074, 1), (15074,))"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"val_predictions.shape, edge_labels_val.shape"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"lr.fit(val_predictions, edge_labels_val)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"lr_test_predictions = lr.predict(test_predictions)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(13300, 1)"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lr_test_predictions.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check if these predictions are calibrated!\n",
"\n",
"If calibration is successful then the ECE after calibration will be lower and the calibration curve will track the ideal diagonal line more closely."
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"calibration_data = [\n",
" calibration_curve(\n",
" y_prob=lr_test_predictions, y_true=edge_labels_test, n_bins=10, normalize=True\n",
" )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ECE (after calibration): 0.0128\n"
]
}
],
"source": [
"for fraction_of_positives, mean_predicted_value in calibration_data:\n",
" ece_post_calibration = expected_calibration_error(\n",
" prediction_probabilities=lr_test_predictions,\n",
" accuracy=fraction_of_positives,\n",
" confidence=mean_predicted_value,\n",
" )\n",
" print(\"ECE (after calibration): {:.4f}\".format(ece_post_calibration))"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"
"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_reliability_diagram(\n",
" calibration_data, lr_test_predictions, ece=[ece_post_calibration]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a final test, check if the accuracy of the model changes after calibration."
]
},
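{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because isotonic regression (and Platt scaling with a positive slope) is non-decreasing, calibration does not reverse the order of the test predictions, yet accuracy at a fixed 0.5 threshold can still change because calibration moves scores across 0.5. A toy illustration, in which `calibrate` is a hypothetical non-decreasing map chosen only for the example:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"labels = np.array([0, 0, 1, 1, 1, 1])\n",
"raw = np.array([0.10, 0.20, 0.30, 0.40, 0.45, 0.80])  # under-confident raw probabilities\n",
"\n",
"\n",
"def calibrate(p):\n",
"    # Hypothetical non-decreasing map: doubles the probability, capped at 1\n",
"    return np.minimum(2.0 * p, 1.0)\n",
"\n",
"\n",
"def accuracy_at_half(p, y):\n",
"    # Predict a link whenever the probability exceeds 0.5, as in the cells below\n",
"    return float(np.mean((p > 0.5).astype(int) == y))\n",
"\n",
"\n",
"print(accuracy_at_half(raw, labels))  # 0.5\n",
"print(accuracy_at_half(calibrate(raw), labels))  # 1.0\n",
"# The ranking of the predictions is unchanged by the monotone map\n",
"print(np.array_equal(np.argsort(raw), np.argsort(calibrate(raw))))  # True\n",
"```"
]
},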
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy of model before calibration: 0.74\n"
]
}
],
"source": [
"y_pred = np.zeros(len(test_predictions))\n",
"if use_platt:\n",
" # the true predictions are the probabilistic outputs\n",
" test_predictions = np.mean(np.array(all_test_probabilistic_predictions), axis=0)\n",
"y_pred[test_predictions.reshape(-1) > 0.5] = 1\n",
"print(\n",
" \"Accuracy of model before calibration: {:.2f}\".format(\n",
" accuracy_score(y_pred=y_pred, y_true=edge_labels_test)\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy for model after calibration: 0.82\n"
]
}
],
"source": [
"y_pred = np.zeros(len(lr_test_predictions))\n",
"y_pred[lr_test_predictions[:, 0] > 0.5] = 1\n",
"print(\n",
" \"Accuracy for model after calibration: {:.2f}\".format(\n",
" accuracy_score(y_pred=y_pred, y_true=edge_labels_test)\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"This notebook demonstrated how to use Platt scaling and isotonic regression to calibrate a GraphSAGE model used for link prediction in a paper citation network. Importantly, it showed that using calibration can improve the classification model's accuracy."
]
},
{
"cell_type": "markdown",
"metadata": {
"nbsphinx": "hidden",
"tags": [
"CloudRunner"
]
},
"source": [
"