gnn.py
import numpy as np
import torch
from torch.nn import Linear
loss_fn = torch.nn.CrossEntropyLoss()
class GNN(torch.nn.Module):
    def __init__(self, layer, num_nodes, num_classes):
        """
        Args:
            layer: pytorch-geometric graph convolution class (e.g. GCNConv)
            num_nodes: int, number of nodes in the graph; also used as the input
                feature dimension, i.e. node features are expected to be
                num_nodes-dimensional (e.g. one-hot encodings)
            num_classes: int, number of possible classes to assign to each node
        """
        super().__init__()
        self.conv1 = layer(num_nodes, 50)
        # The second convolution compresses each node to a 2-dimensional
        # embedding, which is returned alongside the class logits.
        self.conv2 = layer(50, 2)
        self.classifier = Linear(2, num_classes)

    def forward(self, x, edge_index):
        """
        Args:
            x: input features of each node (shape: num_nodes x num_features)
            edge_index: sparse representation of edges between nodes (shape: 2 x num_edges)
        Returns:
            out: class logits for each node (shape: num_nodes x num_classes)
            h: 2-dimensional embedding of each node
        """
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        out = self.classifier(h)
        return out, h
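
# `layer` is expected to be a PyG convolution class constructed as
# layer(in_channels, out_channels) and called as conv(x, edge_index),
# e.g. GCNConv or GraphConv; any layer with that interface should drop in
# unchanged (see the usage sketch at the bottom of the file).
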

def train_step(model, optimizer, data, train_mask):
    """GNN training step

    Args:
        model: GNN
        optimizer: torch optimizer
        data: torch_geometric.data.data.Data
        train_mask: (n_nodes,) boolean nd.array or tensor selecting the training nodes
    Returns:
        loss: float, loss value for the train step
        out: node predictions (logits)
        h: node embeddings
    """
    model.train()
    optimizer.zero_grad()
    out, h = model(data.x, data.edge_index)
    # Only the masked (training) nodes contribute to the loss.
    loss = loss_fn(out[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()
    return loss.detach().cpu().item(), out, h

def masked_accuracy(data, out, mask):
    """Measures graph node classification accuracy with a mask

    Args:
        data: torch_geometric.data.data.Data
        out: torch.Tensor
            predicted logits for each node in a graph (shape: n_nodes x n_classes)
        mask: boolean nd.array or tensor
            indicator for which nodes we are calculating accuracy for (train or test)
    Returns:
        acc: float, fraction of predicted labels matching data labels
    """
    pred = out[mask].argmax(dim=1)
    acc = int((pred == data.y[mask]).sum()) / int(mask.sum())
    return acc

def train(
    models,
    modelnames,
    data,
    train_mask,
    test_mask,
    learning_rate=1e-3,
    n_train_steps=1000,
    n_per_eval=100,
    seed=64
):
    """Train GNNs

    Args:
        models: list of pytorch GNN models
        modelnames: list[str] of model names
        data: torch_geometric.data.data.Data
        train_mask: boolean nd.array or tensor
            indicator for which nodes are in the train set
        test_mask: boolean nd.array or tensor
            indicator for which nodes are in the test set
        learning_rate: float, Adam learning rate
        n_train_steps: int, number of training steps per model
        n_per_eval: int, record metrics and print progress every n_per_eval steps
        seed: int, torch random seed
    Returns:
        metrics: dict[metric_name : dict[model_name : list]]
            loss, (train, test) accuracy, and last-layer embeddings recorded
            every n_per_eval steps for each model
    """
    torch.manual_seed(seed)
    metrics = {'loss': {m: [] for m in modelnames},
               'accuracy': {m: [] for m in modelnames},
               'embeddings': {m: [] for m in modelnames}}
    for model, name in zip(models, modelnames):
        print(name)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        for step in range(n_train_steps):
            loss, out, h = train_step(model, optimizer, data, train_mask)
            if step % n_per_eval == 0:
                train_acc = masked_accuracy(data, out, train_mask)
                test_acc = masked_accuracy(data, out, test_mask)
                metrics['loss'][name].append(loss)
                metrics['accuracy'][name].append((train_acc, test_acc))
                metrics['embeddings'][name].append(h.detach().tolist())
                print(f'[{step} iter.] loss: {np.around(loss, 2)}, test acc: {np.around(test_acc, 2)}')
    return metrics
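

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: train one GCN on
    # the Zachary karate club graph bundled with PyTorch Geometric. The test
    # mask (every node outside the provided train mask) and the hyperparameter
    # values below are illustrative assumptions, not choices from the source.
    from torch_geometric.datasets import KarateClub
    from torch_geometric.nn import GCNConv

    dataset = KarateClub()
    data = dataset[0]
    train_mask = data.train_mask   # one labelled node per community
    test_mask = ~train_mask        # assumption: evaluate on all remaining nodes

    # KarateClub node features are one-hot, so num_features == num_nodes.
    model = GNN(GCNConv, num_nodes=data.num_nodes, num_classes=dataset.num_classes)
    metrics = train([model], ['GCN'], data, train_mask, test_mask,
                    learning_rate=1e-2, n_train_steps=300, n_per_eval=100)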