-
Notifications
You must be signed in to change notification settings - Fork 0
/
modules.py
302 lines (238 loc) · 10.8 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
from collections import OrderedDict
import torch
import numpy as np
import torch_geometric
"""
This file includes basic model for solving integer programming. Note that PreNormLayer is still unused.
"""
class PreNormException(Exception):
pass
class PreNormLayer(torch.nn.Module):
def __init__(self, n_units, shift=True, scale=True):
super().__init__()
assert shift or scale
self.register_buffer('shift', torch.zeros(n_units) if shift else None)
self.register_buffer('scale', torch.ones(n_units) if scale else None)
self.n_units = n_units
self.waiting_updates = False
self.received_updates = False
def forward(self, input_):
if self.waiting_updates:
self.update_stats(input_)
self.received_updates = True
raise PreNormException
if self.shift is not None:
input_ = input_ + self.shift
if self.scale is not None:
input_ = input_ * self.scale
return input_
def start_updates(self):
self.avg = 0
self.var = 0
self.m2 = 0
self.count = 0
self.waiting_updates = True
self.received_updates = False
def update_stats(self, input_):
"""
Online mean and variance estimation. See: Chan et al. (1979) Updating
Formulae and a Pairwise Algorithm for Computing Sample Variances.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
"""
assert self.n_units == 1 or input_.shape[
-1] == self.n_units, f"Expected input dimension of size {self.n_units}, got {input_.shape[-1]}."
input_ = input_.reshape(-1, self.n_units)
sample_avg = input_.mean(dim=0)
sample_var = (input_ - sample_avg).pow(2).mean(dim=0)
sample_count = np.prod(input_.size()) / self.n_units
delta = sample_avg - self.avg
self.m2 = self.var * self.count + sample_var * sample_count + delta ** 2 * self.count * sample_count / (
self.count + sample_count)
self.count += sample_count
self.avg += delta * sample_count / self.count
self.var = self.m2 / self.count if self.count > 0 else 1
def stop_updates(self):
"""
Ends pre-training for that layer, and fixes the layers's parameters.
"""
assert self.count > 0
if self.shift is not None:
self.shift = -self.avg
if self.scale is not None:
self.var[self.var < 1e-8] = 1
self.scale = 1 / torch.sqrt(self.var)
del self.avg, self.var, self.m2, self.count
self.waiting_updates = False
self.trainable = False
class SelectiveNet(torch.nn.Module):
"""
Selective network to decide whether variables need to be predicted.
"""
def __init__(self, feature_size):
super().__init__()
self.select_module = torch.nn.Sequential(
torch.nn.Linear(feature_size, feature_size),
torch.nn.ReLU(),
torch.nn.Linear(feature_size, 1, bias=False),
torch.nn.Sigmoid()
)
def forward(self, feature):
return self.select_module(feature)
class GCNMLPModule(torch.nn.Module):
"""
Reproducde Z = Af(U) and Z_tilde in Neural Diving (Nair et al. 2019)
"""
def __init__(self, cons_input_size=64, vars_input_size=64, output_size=64):
super().__init__()
self.cons_mlp_layer = torch.nn.Sequential(
torch.nn.Linear(cons_input_size, output_size),
torch.nn.ReLU(),
torch.nn.Linear(output_size, output_size),
torch.nn.ReLU()
)
self.vars_mlp_layer = torch.nn.Sequential(
torch.nn.Linear(vars_input_size, output_size),
torch.nn.ReLU(),
torch.nn.Linear(output_size, output_size),
torch.nn.ReLU()
)
self.jump_vars_mlp_layer = torch.nn.Sequential(
torch.nn.Linear(output_size * 2, output_size),
torch.nn.Sigmoid()
)
self.jump_cons_mlp_layer = torch.nn.Sequential(
torch.nn.Linear(output_size * 2, output_size),
torch.nn.Sigmoid()
)
self.vars_layer_norm = torch.nn.LayerNorm(output_size)
self.cons_layer_norm = torch.nn.LayerNorm(output_size)
def jump_network_forward(self, next_z, z_tilde):
next_cons_fea, next_vars_fea = next_z
tilde_cons_fea, tilde_vars_fea = z_tilde
# concat
next_tilde_cons_fea = self.jump_cons_mlp_layer(torch.concat([next_cons_fea, tilde_cons_fea], dim=1))
next_tilde_vars_fea = self.jump_vars_mlp_layer(torch.concat([next_vars_fea, tilde_vars_fea], dim=1))
next_z_tilde = (next_tilde_cons_fea, next_tilde_vars_fea)
return next_z_tilde
def forward(self, z, z_tilde, edge_indices, edge_attrs):
constraint_features, variable_features = z
A = torch.sparse_coo_tensor(
edge_indices,
edge_attrs.squeeze(),
size=(constraint_features.shape[0], variable_features.shape[0])
).detach()
A_T = A.transpose(0, 1).detach()
mlp_cons_fea = self.cons_mlp_layer(constraint_features)
mlp_vars_fea = self.vars_mlp_layer(variable_features)
# graph convolution Z = Af(Z)
gc_cons_fea = A.mm(mlp_vars_fea) + constraint_features
gc_vars_fea = A_T.mm(mlp_cons_fea) + variable_features
# layer normalization
ln_cons_fea = self.cons_layer_norm(gc_cons_fea)
ln_vars_fea = self.vars_layer_norm(gc_vars_fea)
# jump network
next_z = (ln_cons_fea, ln_vars_fea)
next_z_tilde = self.jump_network_forward(next_z, z_tilde)
return next_z, next_z_tilde
class BipartiteGraphConvolution(torch_geometric.nn.MessagePassing):
def __init__(self, emb_size=64):
super().__init__('add')
self.feature_module_left = torch.nn.Sequential(
torch.nn.Linear(emb_size, emb_size)
)
self.feature_module_edge = torch.nn.Sequential(
torch.nn.Linear(1, emb_size, bias=False)
)
self.feature_module_right = torch.nn.Sequential(
torch.nn.Linear(emb_size, emb_size, bias=False)
)
self.feature_module_final = torch.nn.Sequential(
PreNormLayer(1, shift=False),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size)
)
self.post_conv_module = torch.nn.Sequential(
PreNormLayer(1, shift=False)
)
# output_layers
self.output_module = torch.nn.Sequential(
torch.nn.Linear(2 * emb_size, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size),
)
def forward(self, left_features, edge_indices, edge_features, right_features):
output = self.propagate(edge_indices, size=(left_features.shape[0], right_features.shape[0]),
node_features=(left_features, right_features), edge_features=edge_features)
return self.output_module(torch.cat([self.post_conv_module(output), right_features], dim=-1))
def message(self, node_features_i, node_features_j, edge_features):
output = self.feature_module_final(self.feature_module_left(node_features_i)
+ self.feature_module_edge(edge_features)
+ self.feature_module_right(node_features_j))
return output
class GNNModel(torch.nn.Module):
def __init__(self, emb_size=64, cons_nfeats=4, edge_nfeats=1, var_nfeats=13, gcn_mlp_layer_num=5):
super().__init__()
self.gcn_mlp_layer_num = gcn_mlp_layer_num
# CONSTRAINT EMBEDDING
self.cons_embedding = torch.nn.Sequential(
PreNormLayer(cons_nfeats),
torch.nn.Linear(cons_nfeats, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size),
torch.nn.ReLU(),
)
# EDGE EMBEDDING
self.edge_embedding = torch.nn.Sequential(
PreNormLayer(edge_nfeats),
)
# VARIABLE EMBEDDING
self.var_embedding = torch.nn.Sequential(
PreNormLayer(var_nfeats),
torch.nn.Linear(var_nfeats, emb_size),
torch.nn.ReLU(),
torch.nn.Linear(emb_size, emb_size),
torch.nn.ReLU(),
)
self.conv_v_to_c = BipartiteGraphConvolution(emb_size)
self.conv_c_to_v = BipartiteGraphConvolution(emb_size)
self.gcn_mlp_layer = torch.nn.ModuleList(
[GCNMLPModule(emb_size, emb_size, emb_size) for _ in range(gcn_mlp_layer_num)])
def forward(self, constraint_features, edge_indices, edge_attrs, variable_features):
# edge_indices(2,N)
reversed_edge_indices = torch.stack([edge_indices[1], edge_indices[0]], dim=0)
constraint_features = self.cons_embedding(constraint_features.float())
edge_features = self.edge_embedding(edge_attrs)
variable_features = self.var_embedding(variable_features)
constraint_features = self.conv_v_to_c(variable_features, reversed_edge_indices, edge_features,
constraint_features)
variable_features = self.conv_c_to_v(constraint_features, edge_indices, edge_features, variable_features)
z = (constraint_features, variable_features)
z_tlide = (constraint_features, variable_features)
for module in self.gcn_mlp_layer:
next_z, next_z_tlide = module(z, z_tlide, edge_indices, edge_attrs)
z = next_z
z_tlide = next_z_tlide
constraint_features, variable_features = z_tlide
return variable_features
class QuickGELU(torch.nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(torch.nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = torch.nn.MultiheadAttention(d_model, n_head, batch_first=True)
self.ln_1 = torch.nn.LayerNorm(d_model)
self.mlp = torch.nn.Sequential(OrderedDict([
("c_fc", torch.nn.Linear(d_model, d_model)),
("gelu", QuickGELU()),
("c_proj", torch.nn.Linear(d_model, d_model))
]))
self.ln_2 = torch.nn.LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, key_padding_mask=key_padding_mask)[0]
def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
x = x + self.attention(self.ln_1(x), key_padding_mask)
x = x + self.mlp(self.ln_2(x))
return x