Load the dataset

The Cora dataset is a homogeneous network where all nodes are papers and edges between nodes are citation links, e.g. paper A cites paper B.

[4]:
dataset = datasets.Cora()
display(HTML(dataset.description))
graph, _ = dataset.load(largest_connected_component_only=True, str_node_ids=True)
The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words.
[5]:
print(graph.info())
StellarGraph: Undirected multigraph
 Nodes: 2485, Edges: 5209

 Node types:
  paper: [2485]
    Features: float32 vector, length 1433
    Edge types: paper-cites->paper

 Edge types:
    paper-cites->paper: [5209]

Construct splits of the input data

We have to carefully split the data to avoid data leakage and evaluate the algorithms correctly:

  • For computing node embeddings, a Train Graph (graph_train)
  • For training classifiers, a classifier Training Set (examples_train) of positive and negative edges that weren’t used for computing node embeddings
  • For choosing the best classifier, a Model Selection Test Set (examples_model_selection) of positive and negative edges that weren’t used for computing node embeddings or training the classifier
  • For the final evaluation, a Test Graph (graph_test), which has more edges than the Train Graph, for computing test node embeddings, and a Test Set (examples_test) of positive and negative edges not used for computing the test node embeddings, training the classifier, or model selection

Test Graph

We begin with the full graph and use the EdgeSplitter class to produce:

  • Test Graph
  • Test set of positive/negative link examples

The Test Graph is the reduced graph obtained by removing the test set of links from the full graph.

[6]:
# Define an edge splitter on the original graph:
edge_splitter_test = EdgeSplitter(graph)

# Randomly sample a fraction p=0.1 of all positive links, and same number of negative links, from graph, and obtain the
# reduced graph graph_test with the sampled links removed:
graph_test, examples_test, labels_test = edge_splitter_test.train_test_split(
    p=0.1, method="global"
)

print(graph_test.info())
** Sampled 520 positive and 520 negative edges. **
StellarGraph: Undirected multigraph
 Nodes: 2485, Edges: 4689

 Node types:
  paper: [2485]
    Features: float32 vector, length 1433
    Edge types: paper-cites->paper

 Edge types:
    paper-cites->paper: [4689]
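
Note that the Test Graph has 4689 edges, i.e. the original 5209 minus the 520 sampled positive edges. As a quick sanity check (illustrative, not part of the original notebook, and assuming edges() returns (source, target) node ID pairs), the positive test examples should appear in the full graph but not in the reduced Test Graph:

# Illustrative check: positive test edges were removed from graph_test
test_graph_edges = {frozenset(edge) for edge in graph_test.edges()}
full_graph_edges = {frozenset(edge) for edge in graph.edges()}
positive_test_edges = [
    frozenset(pair) for pair, label in zip(examples_test, labels_test) if label == 1
]
assert all(edge in full_graph_edges for edge in positive_test_edges)
assert not any(edge in test_graph_edges for edge in positive_test_edges)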

Train Graph

This time, we use the EdgeSplitter on the Test Graph, and perform a train/test split on the examples to produce:

  • Train Graph
  • Training set of link examples
  • Set of link examples for model selection
[7]:
# Do the same process to compute a training subset from within the test graph
edge_splitter_train = EdgeSplitter(graph_test, graph)
graph_train, examples, labels = edge_splitter_train.train_test_split(
    p=0.1, method="global"
)
(
    examples_train,
    examples_model_selection,
    labels_train,
    labels_model_selection,
) = train_test_split(examples, labels, train_size=0.75, test_size=0.25)

print(graph_train.info())
** Sampled 468 positive and 468 negative edges. **
StellarGraph: Undirected multigraph
 Nodes: 2485, Edges: 4221

 Node types:
  paper: [2485]
    Features: float32 vector, length 1433
    Edge types: paper-cites->paper

 Edge types:
    paper-cites->paper: [4221]
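
The 936 sampled examples (468 positive and 468 negative) are split 75/25 into the classifier training set and the model selection set; a quick size check (illustrative):

# Illustrative size check: 468 positive + 468 negative examples, split 75/25
print(len(examples), len(examples_train), len(examples_model_selection))
# expected output: 936 702 234
assert len(examples_train) + len(examples_model_selection) == len(examples)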

Below is a summary of the different splits that have been created in this section:

[8]:
pd.DataFrame(
    [
        (
            "Training Set",
            len(examples_train),
            "Train Graph",
            "Test Graph",
            "Train the Link Classifier",
        ),
        (
            "Model Selection",
            len(examples_model_selection),
            "Train Graph",
            "Test Graph",
            "Select the best Link Classifier model",
        ),
        (
            "Test set",
            len(examples_test),
            "Test Graph",
            "Full Graph",
            "Evaluate the best Link Classifier",
        ),
    ],
    columns=("Split", "Number of Examples", "Hidden from", "Picked from", "Use"),
).set_index("Split")
[8]:
                 Number of Examples  Hidden from  Picked from                                     Use
Split
Training Set                    702  Train Graph   Test Graph               Train the Link Classifier
Model Selection                 234  Train Graph   Test Graph  Select the best Link Classifier model
Test set                       1040   Test Graph   Full Graph      Evaluate the best Link Classifier
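
The "Hidden from" column can also be verified directly. For instance (an illustrative leakage check, not part of the original notebook), the positive training examples should be absent from the Train Graph used for embeddings, but present in the Test Graph they were sampled from:

# Illustrative leakage check for the classifier training examples
train_graph_edges = {frozenset(edge) for edge in graph_train.edges()}
test_graph_edges = {frozenset(edge) for edge in graph_test.edges()}
positive_train_edges = [
    frozenset(pair) for pair, label in zip(examples_train, labels_train) if label == 1
]
assert not any(edge in train_graph_edges for edge in positive_train_edges)
assert all(edge in test_graph_edges for edge in positive_train_edges)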

Node2Vec

We use Node2Vec [1] to calculate node embeddings. These embeddings are learned in such a way that nodes that are close in the graph remain close in the embedding space. Node2Vec first runs biased random walks on the graph to obtain context pairs, and then uses these pairs to train a Word2Vec model; a sketch of how the walk bias works follows the parameter list below.

These are the set of parameters we can use:

  • p - Random walk parameter “p”
  • q - Random walk parameter “q”
  • dimensions - Dimensionality of node2vec embeddings
  • num_walks - Number of walks from each node
  • walk_length - Length of each random walk
  • window_size - Context window size for Word2Vec
  • num_iter - Number of SGD iterations (epochs)
  • workers - Number of workers for Word2Vec
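
For intuition, p and q bias a second-order random walk: after stepping from a node prev to the current node curr, each neighbour nxt of curr is weighted by 1/p if it returns to prev, by 1 if it is also a neighbour of prev (distance one), and by 1/q otherwise (moving further away). A minimal sketch of this rule (illustrative only, not the BiasedRandomWalk internals; `neighbours` is an assumed mapping from each node to the set of its neighbours):

# Sketch of the node2vec second-order walk bias (illustrative only)
def unnormalised_weight(prev, curr, nxt, neighbours, p, q):
    if nxt == prev:
        return 1.0 / p  # step back to the previous node
    elif nxt in neighbours[prev]:
        return 1.0  # stay at distance one from the previous node
    else:
        return 1.0 / q  # move further away from the previous node

With p = q = 1.0, as set below, every neighbour is weighted equally and the walk reduces to a uniform random walk.
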
[9]:
p = 1.0
q = 1.0
dimensions = 128
num_walks = 10
walk_length = 80
window_size = 10
num_iter = 1
workers = multiprocessing.cpu_count()
[10]:
from stellargraph.data import BiasedRandomWalk
from gensim.models import Word2Vec


def node2vec_embedding(graph, name):
    rw = BiasedRandomWalk(graph)
    walks = rw.run(graph.nodes(), n=num_walks, length=walk_length, p=p, q=q)
    print(f"Number of random walks for '{name}': {len(walks)}")

    # `size` and `iter` are the keyword names used by gensim < 4.0
    # (renamed to `vector_size` and `epochs` in gensim 4.x)
    model = Word2Vec(
        walks,
        size=dimensions,
        window=window_size,
        min_count=0,
        sg=1,
        workers=workers,
        iter=num_iter,
    )

    def get_embedding(u):
        return model.wv[u]

    return get_embedding
[11]:
embedding_train = node2vec_embedding(graph_train, "Train Graph")
Number of random walks for 'Train Graph': 24850
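
The returned function maps a node ID to its learned 128-dimensional vector; an illustrative lookup:

# Illustrative usage: look up the embedding of one node from the Train Graph
sample_node = list(graph_train.nodes())[0]
print(embedding_train(sample_node).shape)  # expected: (128,), i.e. (dimensions,)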