Commit 55bf21d
Reformat using black
Committed Oct 14, 2021
1 parent 2019ac4, commit 55bf21d
16 files changed: +1503 −960 lines changed
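Note: per the commit message, this diff is the output of the Black autoformatter (quote normalization, line wrapping, and whitespace only; no behavioral changes). The exact command is not recorded in the commit metadata; an invocation along these lines would reproduce this kind of diff:

    # Assumed invocation, not recorded in the commit (requires black to be installed):
    python -m black chem/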
 

chem/data/collect_all.py (+113 −28)

@@ -1,4 +1,3 @@
-
 import os
 import sys
 from sklearn.metrics import f1_score
@@ -17,19 +16,23 @@
 from data import ClassificationData, JCIClassificationData
 
 import logging
-logging.getLogger('pysmiles').setLevel(logging.CRITICAL)
 
+logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
-class PartOfNet(pl.LightningModule):
 
+class PartOfNet(pl.LightningModule):
     def __init__(self, in_length, loops=10):
         super().__init__()
-        self.loops=loops
+        self.loops = loops
         self.left_graph_net = tgnn.GATConv(in_length, in_length)
         self.right_graph_net = tgnn.GATConv(in_length, in_length)
         self.attention = nn.Linear(in_length, 1)
         self.global_attention = tgnn.GlobalAttention(self.attention)
-        self.output_net = nn.Sequential(nn.Linear(2*in_length,2*in_length), nn.Linear(2*in_length,in_length), nn.Linear(in_length,500))
+        self.output_net = nn.Sequential(
+            nn.Linear(2 * in_length, 2 * in_length),
+            nn.Linear(2 * in_length, in_length),
+            nn.Linear(in_length, 500),
+        )
         self.f1 = F1(1, threshold=0.5)
 
     def _execute(self, batch, batch_idx):
@@ -40,60 +43,130 @@ def _execute(self, batch, batch_idx):
 
     def training_step(self, *args, **kwargs):
         loss, f1 = self._execute(*args, **kwargs)
-        self.log('train_loss', loss.detach().item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
-        self.log('train_f1', f1.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        self.log(
+            "train_loss",
+            loss.detach().item(),
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
+        self.log(
+            "train_f1",
+            f1.item(),
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
         return loss
 
     def validation_step(self, *args, **kwargs):
         with torch.no_grad():
             loss, f1 = self._execute(*args, **kwargs)
-            self.log('val_loss', loss.detach().item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
-            self.log('val_f1', f1.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
+            self.log(
+                "val_loss",
+                loss.detach().item(),
+                on_step=True,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+            self.log(
+                "val_f1",
+                f1.item(),
+                on_step=True,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
             return loss
 
     def forward(self, x):
         a = self.left_graph_net(x.x_s, x.edge_index_s.long())
         b = self.right_graph_net(x.x_t, x.edge_index_t.long())
-        return self.output_net(torch.cat([self.global_attention(a, x.x_s_batch),self.global_attention(b, x.x_t_batch)], dim=1))
+        return self.output_net(
+            torch.cat(
+                [
+                    self.global_attention(a, x.x_s_batch),
+                    self.global_attention(b, x.x_t_batch),
+                ],
+                dim=1,
+            )
+        )
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters())
         return optimizer
 
 
 class JCINet(pl.LightningModule):
-
     def __init__(self, in_length, hidden_length, num_classes, loops=10):
         super().__init__()
-        self.loops=loops
+        self.loops = loops
 
-        self.node_net = nn.Sequential(nn.Linear(self.loops*in_length,hidden_length), nn.ReLU())
+        self.node_net = nn.Sequential(
+            nn.Linear(self.loops * in_length, hidden_length), nn.ReLU()
+        )
         self.embedding = torch.nn.Embedding(800, in_length)
         self.left_graph_net = tgnn.GATConv(in_length, in_length, dropout=0.1)
         self.final_graph_net = tgnn.GATConv(in_length, hidden_length, dropout=0.1)
         self.attention = nn.Linear(hidden_length, 1)
         self.global_attention = tgnn.GlobalAttention(self.attention)
-        self.output_net = nn.Sequential(nn.Linear(hidden_length,hidden_length), nn.Linear(hidden_length, num_classes))
+        self.output_net = nn.Sequential(
+            nn.Linear(hidden_length, hidden_length),
+            nn.Linear(hidden_length, num_classes),
+        )
        self.f1 = F1(num_classes, threshold=0.5)
 
     def _execute(self, batch, batch_idx):
         pred = self(batch)
         labels = batch.label.float()
         loss = F.binary_cross_entropy_with_logits(pred, labels)
-        f1 = f1_score(labels.cpu()>0.5, torch.sigmoid(pred).cpu()>0.5, average="micro")
+        f1 = f1_score(
+            labels.cpu() > 0.5, torch.sigmoid(pred).cpu() > 0.5, average="micro"
+        )
         return loss, f1
 
     def training_step(self, *args, **kwargs):
         loss, f1 = self._execute(*args, **kwargs)
-        self.log('train_loss', loss.detach().item(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('train_f1', f1.item(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log(
+            "train_loss",
+            loss.detach().item(),
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
+        self.log(
+            "train_f1",
+            f1.item(),
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
         return loss
 
     def validation_step(self, *args, **kwargs):
         with torch.no_grad():
             loss, f1 = self._execute(*args, **kwargs)
-            self.log('val_loss', loss.detach().item(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
-            self.log('val_f1', f1.item(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
+            self.log(
+                "val_loss",
+                loss.detach().item(),
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+            self.log(
+                "val_f1",
+                f1.item(),
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
             return loss
 
     def forward(self, x):
@@ -102,7 +175,7 @@ def forward(self, x):
         for _ in range(self.loops):
             a = self.left_graph_net(a, x.edge_index.long())
             l.append(a)
-        at = self.global_attention(self.node_net(torch.cat(l,dim=1)), x.x_batch)
+        at = self.global_attention(self.node_net(torch.cat(l, dim=1)), x.x_batch)
         return self.output_net(at)
 
     def configure_optimizers(self):
@@ -116,28 +189,40 @@ def train(train_loader, validation_loader):
     else:
         trainer_kwargs = dict(gpus=0)
     net = JCINet(100, 100, 500)
-    tb_logger = pl_loggers.CSVLogger('../../logs/')
+    tb_logger = pl_loggers.CSVLogger("../../logs/")
     checkpoint_callback = ModelCheckpoint(
         dirpath=os.path.join(tb_logger.log_dir, "checkpoints"),
         filename="{epoch}-{step}-{val_loss:.7f}",
         save_top_k=5,
         save_last=True,
         verbose=True,
-        monitor='val_loss',
-        mode='min'
+        monitor="val_loss",
+        mode="min",
+    )
+    trainer = pl.Trainer(
+        logger=tb_logger,
+        callbacks=[checkpoint_callback],
+        replace_sampler_ddp=False,
+        **trainer_kwargs
     )
-    trainer = pl.Trainer(logger=tb_logger, callbacks=[checkpoint_callback], replace_sampler_ddp=False, **trainer_kwargs)
     trainer.fit(net, train_loader, val_dataloaders=validation_loader)
 
 
 if __name__ == "__main__":
     batch_size = int(sys.argv[1])
-    #vl = ClassificationData("data/full_chebi", split="validation")
-    #tr = ClassificationData("data/full_chebi", split="train")
+    # vl = ClassificationData("data/full_chebi", split="validation")
+    # tr = ClassificationData("data/full_chebi", split="train")
     tr = JCIClassificationData("data/JCI_data", split="train")
     vl = JCIClassificationData("data/JCI_data", split="validation")
 
-    train_loader = DataLoader(tr, shuffle=True, batch_size=batch_size, follow_batch=["x", "edge_index", "label"])
-    validation_loader = DataLoader(vl, batch_size=batch_size, follow_batch=["x", "edge_index", "label"])
+    train_loader = DataLoader(
+        tr,
+        shuffle=True,
+        batch_size=batch_size,
+        follow_batch=["x", "edge_index", "label"],
+    )
+    validation_loader = DataLoader(
+        vl, batch_size=batch_size, follow_batch=["x", "edge_index", "label"]
+    )
 
     train(train_loader, validation_loader)

chem/data/datasets.py (+784 −677)

Large diffs are not rendered by default.

chem/data/reader.py (+14 −16)

@@ -5,7 +5,6 @@
 
 
 class DataReader:
-
     def _get_raw_data(self, row):
         return row[0]
 
@@ -23,11 +22,12 @@ def _read_label(self, raw_label):
         return raw_label
 
     def to_data(self, row):
-        return self._read_data(self._get_raw_data(row)), self._read_label(self._get_raw_label(row))
+        return self._read_data(self._get_raw_data(row)), self._read_label(
+            self._get_raw_label(row)
+        )
 
 
 class ChemDataReader(DataReader):
-
     @classmethod
     def name(cls):
         return "smiles_token"
@@ -54,7 +54,6 @@ def _read_data(self, raw_data):
 
 
 class OrdReader(DataReader):
-
     @classmethod
     def name(cls):
         return "ord"
@@ -64,7 +63,6 @@ def _read_data(self, raw_data):
 
 
 class MolDatareader(DataReader):
-
     @classmethod
     def name(cls):
         return "mol"
@@ -73,10 +71,10 @@ def __init__(self, batch_size, **kwargs):
         super().__init__(batch_size, **kwargs)
         self.cache = []
 
-
-
     def to_data(self, row):
-        return self.get_encoded_mol(row[self.SMILES_INDEX], self.cache),self._get_label(row)
+        return self.get_encoded_mol(
+            row[self.SMILES_INDEX], self.cache
+        ), self._get_label(row)
 
     def get_encoded_mol(self, smiles, cache):
         try:
@@ -102,25 +100,26 @@ def get_encoded_mol(self, smiles, cache):
 
 
 class GraphDataset(DataReader):
-
     @classmethod
     def name(cls):
         return "graph"
 
     def __init__(self, batch_size, **kwargs):
         super().__init__(batch_size, **kwargs)
-        self.collater = Collater(follow_batch=["x", "edge_attr", "edge_index", "label"], exclude_keys=[])
+        self.collater = Collater(
+            follow_batch=["x", "edge_attr", "edge_index", "label"], exclude_keys=[]
+        )
         self.cache = []
 
     def process_smiles(self, smiles):
-
         def cache(m):
             try:
                 x = self.cache.index(m)
             except ValueError:
                 x = len(self.cache)
                 self.cache.append(m)
             return x
+
         try:
             mol = ps.read_smiles(smiles)
         except ValueError:
@@ -150,7 +149,7 @@ def to_data(self, df):
         for row in df.values[:DATA_LIMIT]:
             d = self.process_smiles(row[self.SMILES_INDEX])
             if d is not None and d.num_nodes > 1:
-                d.y = torch.tensor(row[self.LABEL_INDEX:].astype(bool)).unsqueeze(0)
+                d.y = torch.tensor(row[self.LABEL_INDEX :].astype(bool)).unsqueeze(0)
                 yield d
 
 
@@ -160,15 +159,15 @@ def to_data(self, df):
     pass
 else:
     from k_gnn.dataloader import collate
-    class GraphTwoDataset(GraphDataset):
 
+    class GraphTwoDataset(GraphDataset):
         @classmethod
         def name(cls):
             return "graph_k2"
 
         def to_data(self, df: pd.DataFrame):
             for data in super().to_data(df)[:DATA_LIMIT]:
-                if data.num_nodes >=6:
+                if data.num_nodes >= 6:
                     x = data.x
                     data.x = data.x.unsqueeze(0)
                     data = TwoMalkin()(data)
@@ -178,9 +177,8 @@ def to_data(self, df: pd.DataFrame):
         def collate(self, list_of_tuples):
             return collate(list_of_tuples)
 
-
     class JCIExtendedGraphTwoData(JCIExtendedBase, GraphTwoDataset):
         pass
 
     class JCIGraphTwoData(JCIBase, GraphTwoDataset):
-        pass
+        pass

chem/data/structures.py (+16 −8)

@@ -28,16 +28,15 @@ def __init__(self, ppd: PrePairData, graph):
         self.label = ppd.label
 
     def __inc__(self, key, value):
-        if key == 'edge_index_s':
+        if key == "edge_index_s":
             return self.x_s.size(0)
-        if key == 'edge_index_t':
+        if key == "edge_index_t":
             return self.x_t.size(0)
         else:
             return super().__inc__(key, value)
 
 
 class XYData(torch.utils.data.Dataset, TransferableDataType):
-
     def __getitem__(self, index) -> T_co:
         return self.x[index], self.y[index]
 
@@ -52,7 +51,9 @@ def __init__(self, x, y, additional_fields=None, **kwargs):
         self.x = x
         self.y = y
 
-        self.additional_fields = list(additional_fields.keys()) if additional_fields else []
+        self.additional_fields = (
+            list(additional_fields.keys()) if additional_fields else []
+        )
 
     def to_x(self, device):
         return self.x.to(device)
@@ -63,15 +64,22 @@ def to_y(self, device):
     def to(self, device):
         x = self.to_x(device)
         y = self.to_y(device)
-        return XYData(x, y, additional_fields={k: getattr(self, k) for k in self.additional_fields} )
+        return XYData(
+            x,
+            y,
+            additional_fields={k: getattr(self, k) for k in self.additional_fields},
+        )
 
 
 class XYMolData(XYData):
-
     def to_x(self, device):
         l = []
         for g in self.x:
             graph = g.copy()
-            nx.set_node_attributes(graph, {k: v.to(device) for k, v in nx.get_node_attributes(g, "x").items()}, "x")
+            nx.set_node_attributes(
+                graph,
+                {k: v.to(device) for k, v in nx.get_node_attributes(g, "x").items()},
+                "x",
+            )
             l.append(graph)
-        return tuple(l)
+        return tuple(l)

chem/model.py (+45 −15)

@@ -8,6 +8,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.metrics import F1
 
+
 class ChEBIRecNN(pl.LightningModule):
     def __init__(self):
         super(ChEBIRecNN, self).__init__()
@@ -22,10 +23,7 @@ def __init__(self):
         self._f1 = F1(500, threshold=0.5)
         self._loss_fun = F.binary_cross_entropy_with_logits
 
-        self.metrics = {
-            "loss": self._loss_fun,
-            "f1": self._f1
-        }
+        self.metrics = {"loss": self._loss_fun, "f1": self._f1}
 
         self.c1 = nn.Linear(self.length, self.length)
         self.c2 = nn.Linear(self.length, self.length)
@@ -34,11 +32,31 @@ def __init__(self):
         self.c5 = nn.Linear(self.length, self.length)
         self.c = {1: self.c1, 2: self.c2, 3: self.c3, 4: self.c4, 5: self.c5}
 
-        self.NN_single_node = nn.Sequential(nn.Linear(self.atom_enc, self.length), nn.ReLU(), nn.Linear(self.length, self.length))
-        self.merge = nn.Sequential(nn.Linear(2*self.length, self.length), nn.ReLU(), nn.Linear(self.length, self.length))
-        self.register_parameter("attention_weight", torch.nn.Parameter(torch.rand(self.length,1, requires_grad=True)))
-        self.register_parameter("dag_weight", torch.nn.Parameter(torch.rand(self.length,1, requires_grad=True)))
-        self.final = nn.Sequential(nn.Linear(self.length, self.length), nn.ReLU(), nn.Linear(self.length, self.length), nn.ReLU(), nn.Linear(self.length, self.num_of_classes))
+        self.NN_single_node = nn.Sequential(
+            nn.Linear(self.atom_enc, self.length),
+            nn.ReLU(),
+            nn.Linear(self.length, self.length),
+        )
+        self.merge = nn.Sequential(
+            nn.Linear(2 * self.length, self.length),
+            nn.ReLU(),
+            nn.Linear(self.length, self.length),
+        )
+        self.register_parameter(
+            "attention_weight",
+            torch.nn.Parameter(torch.rand(self.length, 1, requires_grad=True)),
+        )
+        self.register_parameter(
+            "dag_weight",
+            torch.nn.Parameter(torch.rand(self.length, 1, requires_grad=True)),
+        )
+        self.final = nn.Sequential(
+            nn.Linear(self.length, self.length),
+            nn.ReLU(),
+            nn.Linear(self.length, self.length),
+            nn.ReLU(),
+            nn.Linear(self.length, self.num_of_classes),
+        )
 
     def forward(self, molecules: Iterable[Molecule]):
         return torch.stack([self._proc_single_mol(molecule) for molecule in molecules])
@@ -60,7 +78,12 @@ def _proc_single_mol(self, molecule):
             output = F.relu(self.merge(inp)) + inp_prev
             for succ in dag.successors(node):
                 try:
-                    inputs[succ] = torch.cat((self.c[num_inputs[succ]](inputs[succ]), output.unsqueeze(0)))
+                    inputs[succ] = torch.cat(
+                        (
+                            self.c[num_inputs[succ]](inputs[succ]),
+                            output.unsqueeze(0),
+                        )
+                    )
                     num_inputs[succ] += 1
                 except KeyError:
                     inputs[succ] = output.unsqueeze(0)
@@ -87,23 +110,30 @@ def validation_step(self, batch, batch_idx):
         return self._calculate_metrics(prediction, labels, prefix="val_")
 
     def process_atom(self, node, molecule):
-        return F.dropout(F.relu(self.NN_single_node(molecule.get_atom_features(node).to(self.device))), p=0.1)
+        return F.dropout(
+            F.relu(
+                self.NN_single_node(molecule.get_atom_features(node).to(self.device))
+            ),
+            p=0.1,
+        )
 
     def training_epoch_end(self, outputs) -> None:
         for metric in self.metrics:
-            avg = torch.stack([o[metric] for o in outputs]).mean()
+            avg = torch.stack([o[metric] for o in outputs]).mean()
             self.log(metric, avg)
 
     def validation_epoch_end(self, outputs) -> None:
         if not self.trainer.running_sanity_check:
             for metric in self.metrics:
-                avg = torch.stack([o[metric] for o in outputs]).mean()
+                avg = torch.stack([o[metric] for o in outputs]).mean()
                 self.log("val_" + metric, avg)
 
     @staticmethod
     def attention(weights, x):
-        return torch.sum(torch.mul(torch.softmax(torch.matmul(x, weights), dim=0),x), dim=0)
+        return torch.sum(
+            torch.mul(torch.softmax(torch.matmul(x, weights), dim=0), x), dim=0
+        )
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
-        return optimizer
+        return optimizer

chem/models/base.py (+77 −19)

@@ -11,7 +11,7 @@
 import logging
 import sys
 
-logging.getLogger('pysmiles').setLevel(logging.CRITICAL)
+logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
 
 class JCIBaseNet(pl.LightningModule):
@@ -39,19 +39,59 @@ def _execute(self, batch, batch_idx):
 
     def training_step(self, *args, **kwargs):
         loss, f1, mse = self._execute(*args, **kwargs)
-        self.log('train_loss', loss.detach().item(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('train_f1', f1.detach().item(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log('train_mse', mse.detach().item(), on_step=False, on_epoch=True,
-                 prog_bar=True, logger=True)
+        self.log(
+            "train_loss",
+            loss.detach().item(),
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
+        self.log(
+            "train_f1",
+            f1.detach().item(),
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
+        self.log(
+            "train_mse",
+            mse.detach().item(),
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
         return loss
 
     def validation_step(self, *args, **kwargs):
         with torch.no_grad():
             loss, f1, mse = self._execute(*args, **kwargs)
-            self.log('val_loss', loss.detach().item(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
-            self.log('val_f1', f1.detach().item(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
-            self.log('val_mse', mse.detach().item(), on_step=False, on_epoch=True,
-                     prog_bar=True, logger=True)
+            self.log(
+                "val_loss",
+                loss.detach().item(),
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+            self.log(
+                "val_f1",
+                f1.detach().item(),
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+            self.log(
+                "val_mse",
+                mse.detach().item(),
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
         return loss
 
     def forward(self, x):
@@ -62,7 +102,14 @@ def configure_optimizers(self):
         return optimizer
 
     @classmethod
-    def run(cls, data, name, model_args: list = None, model_kwargs: dict = None, weighted=False):
+    def run(
+        cls,
+        data,
+        name,
+        model_args: list = None,
+        model_kwargs: dict = None,
+        weighted=False,
+    ):
         if model_args is None:
             model_args = []
         if model_kwargs is None:
@@ -76,8 +123,10 @@ def run(cls, data, name, model_args: list = None, model_kwargs: dict = None, wei
         if weighted:
             weights = model_kwargs.get("weights")
             if weights is None:
-                weights = 1 + torch.sum(torch.cat([data.y for data in train_data]).float(), dim=0)
-                weights = torch.mean(weights)/weights
+                weights = 1 + torch.sum(
+                    torch.cat([data.y for data in train_data]).float(), dim=0
+                )
+                weights = torch.mean(weights) / weights
                 name += "__weighted"
             model_kwargs["weights"] = weights
         else:
@@ -91,25 +140,34 @@ def run(cls, data, name, model_args: list = None, model_kwargs: dict = None, wei
         else:
             trainer_kwargs = dict(gpus=0)
 
-        tb_logger = pl_loggers.TensorBoardLogger('logs/', name=name)
+        tb_logger = pl_loggers.TensorBoardLogger("logs/", name=name)
         checkpoint_callback = ModelCheckpoint(
             dirpath=os.path.join(tb_logger.log_dir, "checkpoints"),
             filename="{epoch}-{step}-{val_loss:.7f}",
             save_top_k=5,
            save_last=True,
             verbose=True,
-            monitor='val_loss',
-            mode='min'
+            monitor="val_loss",
+            mode="min",
        )
 
         # Calculate weights per class
 
         net = cls(*model_args, **model_kwargs)
 
         # Early stopping seems to be bugged right now with ddp accelerator :(
-        es = EarlyStopping(monitor='val_loss', patience=10, min_delta=0.00,
-                           verbose=False,
+        es = EarlyStopping(
+            monitor="val_loss",
+            patience=10,
+            min_delta=0.00,
+            verbose=False,
         )
 
-        trainer = pl.Trainer(logger=tb_logger,max_epochs=300, callbacks=[checkpoint_callback], replace_sampler_ddp=False, **trainer_kwargs)
-        trainer.fit(net, train_data, val_dataloaders=val_data)
+        trainer = pl.Trainer(
+            logger=tb_logger,
+            max_epochs=300,
+            callbacks=[checkpoint_callback],
+            replace_sampler_ddp=False,
+            **trainer_kwargs
+        )
+        trainer.fit(net, train_data, val_dataloaders=val_data)

chem/models/chemyk.py (+20 −9)

@@ -11,7 +11,7 @@
 import logging
 from chem.models.base import JCIBaseNet
 
-logging.getLogger('pysmiles').setLevel(logging.CRITICAL)
+logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
 
 class ChemYK(JCIBaseNet):
@@ -28,25 +28,36 @@ def __init__(self, in_d, out_d, num_classes, **kwargs):
         self.w_l = nn.Linear(d_internal, d_internal)
         self.w_r = nn.Linear(d_internal, d_internal)
         self.norm = nn.LayerNorm(d_internal)
-        self.output = nn.Sequential(nn.Linear(in_d, in_d), nn.ReLU(), nn.Dropout(0.2), nn.Linear(in_d, num_classes))
+        self.output = nn.Sequential(
+            nn.Linear(in_d, in_d),
+            nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(in_d, num_classes),
+        )
 
     def forward(self, data, *args, **kwargs):
         m = self.embedding(data.x)
         max_width = m.shape[1]
-        h = [m] #torch.zeros(emb.shape[0], max_width, *emb.shape[1:])
-        #h[:, 0] = emb
+        h = [m]  # torch.zeros(emb.shape[0], max_width, *emb.shape[1:])
+        # h[:, 0] = emb
         for width in range(1, max_width):
-            l = torch.stack(tuple(h[i][:, :(max_width-width)] for i in range(width)))
-            r = torch.stack(tuple(h[i][:,(width-i):] for i in range(0, width))).flip(0)
-            m = self.merge(l,r)
+            l = torch.stack(tuple(h[i][:, : (max_width - width)] for i in range(width)))
+            r = torch.stack(
+                tuple(h[i][:, (width - i) :] for i in range(0, width))
+            ).flip(0)
+            m = self.merge(l, r)
             h.append(m)
         return self.output(m).squeeze(1)
 
     def merge(self, l, r):
         x = torch.stack([self.a_l(l), self.a_r(r)])
         beta = torch.softmax(x, 0)
-        return F.leaky_relu(self.attention(torch.sum(beta*torch.stack([self.w_l(l), self.w_r(r)]), dim=0)))
+        return F.leaky_relu(
+            self.attention(
+                torch.sum(beta * torch.stack([self.w_l(l), self.w_r(r)]), dim=0)
+            )
+        )
 
     def attention(self, parts):
         at = torch.softmax(self.s(parts), 1)
-        return torch.sum(at*parts, dim=0)
+        return torch.sum(at * parts, dim=0)

chem/models/electra.py (+3 −1)

@@ -5,10 +5,12 @@
 import logging
 from chem.models.base import JCIBaseNet
 
-logging.getLogger('pysmiles').setLevel(logging.CRITICAL)
+logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
+
 
 class ElectraPre(JCIBaseNet):
     NAME = "Electra"
+
     def __init__(self, config=None, **kwargs):
         super().__init__(**kwargs)
         config = ElectraConfig(**config)

chem/models/graph.py (+32 −16)

@@ -9,7 +9,7 @@
 
 from chem.models.base import JCIBaseNet
 
-logging.getLogger('pysmiles').setLevel(logging.CRITICAL)
+logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
 
 class JCIGraphNet(JCIBaseNet):
@@ -23,11 +23,16 @@ def __init__(self, in_length, hidden_length, num_classes, **kwargs):
         self.conv2 = tgnn.GraphConv(in_length, in_length)
         self.conv3 = tgnn.GraphConv(in_length, hidden_length)
 
-        self.output_net = nn.Sequential(nn.Linear(hidden_length,hidden_length), nn.ELU(), nn.Linear(hidden_length,hidden_length), nn.ELU(), nn.Linear(hidden_length, num_classes))
+        self.output_net = nn.Sequential(
+            nn.Linear(hidden_length, hidden_length),
+            nn.ELU(),
+            nn.Linear(hidden_length, hidden_length),
+            nn.ELU(),
+            nn.Linear(hidden_length, num_classes),
+        )
 
         self.dropout = nn.Dropout(0.1)
 
-
     def forward(self, x):
         a = self.embedding(x.x)
         a = self.dropout(a)
@@ -38,24 +43,37 @@ def forward(self, x):
         a = scatter_add(a, x.batch, dim=0)
         return self.output_net(a)
 
+
 class JCIGraphAttentionNet(JCIBaseNet):
     NAME = "AGNN"
 
     def __init__(self, in_length, hidden_length, num_classes, **kwargs):
         super().__init__(num_classes, **kwargs)
         self.embedding = torch.nn.Embedding(800, in_length)
         self.edge_embedding = torch.nn.Embedding(4, in_length)
-        in_length = in_length+10
-        self.conv1 = tgnn.GATConv(in_length, in_length, 5, concat=False, dropout=0.1, add_self_loops=True)
-        self.conv2 = tgnn.GATConv(in_length, in_length, 5, concat=False, add_self_loops=True)
-        self.conv3 = tgnn.GATConv(in_length, in_length, 5, concat=False, add_self_loops=True)
-        self.conv4 = tgnn.GATConv(in_length, in_length, 5, concat=False, add_self_loops=True)
-        self.conv5 = tgnn.GATConv(in_length, in_length, 5, concat=False, add_self_loops=True)
-        self.output_net = nn.Sequential(nn.Linear(in_length, hidden_length),
-                                        nn.LeakyReLU(),
-                                        nn.Linear(hidden_length, hidden_length),
-                                        nn.LeakyReLU(),
-                                        nn.Linear(hidden_length, num_classes))
+        in_length = in_length + 10
+        self.conv1 = tgnn.GATConv(
+            in_length, in_length, 5, concat=False, dropout=0.1, add_self_loops=True
+        )
+        self.conv2 = tgnn.GATConv(
+            in_length, in_length, 5, concat=False, add_self_loops=True
+        )
+        self.conv3 = tgnn.GATConv(
+            in_length, in_length, 5, concat=False, add_self_loops=True
+        )
+        self.conv4 = tgnn.GATConv(
+            in_length, in_length, 5, concat=False, add_self_loops=True
+        )
+        self.conv5 = tgnn.GATConv(
+            in_length, in_length, 5, concat=False, add_self_loops=True
+        )
+        self.output_net = nn.Sequential(
+            nn.Linear(in_length, hidden_length),
+            nn.LeakyReLU(),
+            nn.Linear(hidden_length, hidden_length),
+            nn.LeakyReLU(),
+            nn.Linear(hidden_length, num_classes),
+        )
         self.dropout = nn.Dropout(0.1)
 
     def forward(self, batch):
@@ -71,5 +89,3 @@ def forward(self, batch):
         a = scatter_mean(a, batch.batch, dim=0)
         a = self.output_net(a)
         return a
-
-

chem/models/graph_k2.py (+8 −5)

@@ -11,7 +11,7 @@
 
 from chem.models.base import JCIBaseNet
 
-logging.getLogger('pysmiles').setLevel(logging.CRITICAL)
+logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
 
 class JCIGraphK2Net(JCIBaseNet):
@@ -29,9 +29,13 @@ def __init__(self, in_length, hidden_length, num_classes, weights=None, **kwargs
         self.conv2_2 = tgnn.GraphConv(in_length, in_length)
         self.conv2_3 = tgnn.GraphConv(in_length, hidden_length)
 
-        self.output_net = nn.Sequential(nn.Linear(hidden_length*2, hidden_length), nn.ELU(),
-                                        nn.Linear(hidden_length, hidden_length), nn.ELU(),
-                                        nn.Linear(hidden_length, num_classes))
+        self.output_net = nn.Sequential(
+            nn.Linear(hidden_length * 2, hidden_length),
+            nn.ELU(),
+            nn.Linear(hidden_length, hidden_length),
+            nn.ELU(),
+            nn.Linear(hidden_length, num_classes),
+        )
 
         self.dropout = nn.Dropout(0.1)
 
@@ -54,4 +58,3 @@ def forward(self, x):
 
         a = self.dropout(a)
         return self.output_net(a)
-

chem/models/graphyk.py (+32 −13)

@@ -9,7 +9,7 @@
 from chem.models.base import JCIBaseNet
 
 
-logging.getLogger('pysmiles').setLevel(logging.CRITICAL)
+logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
 
 class ChemYK(JCIBaseNet):
@@ -26,16 +26,36 @@ def __init__(self, in_d, out_d, num_classes, **kwargs):
         self.softmax = nn.Softmax()
         self.attention_weight = nn.Linear(in_d, in_d)
         self.top_level_attention_weight = nn.Linear(in_d, in_d)
-        self.output = nn.Sequential(nn.Linear(in_d, in_d), nn.ReLU(), nn.Dropout(0.2), nn.Linear(in_d, num_classes))
+        self.output = nn.Sequential(
+            nn.Linear(in_d, in_d),
+            nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(in_d, num_classes),
+        )
 
     def forward(self, batch, max_width=5):
         result = []
         for data in batch.x:
             # Calculate embeddings
-            clusters = [(frozenset({x, y}), self.merge([(self.embedding(data.nodes[x]["x"]), self.embedding(data.nodes[y]["x"]))])) for x,y in data.edges]
+            clusters = [
+                (
+                    frozenset({x, y}),
+                    self.merge(
+                        [
+                            (
+                                self.embedding(data.nodes[x]["x"]),
+                                self.embedding(data.nodes[y]["x"]),
+                            )
+                        ]
+                    ),
+                )
+                for x, y in data.edges
+            ]
             while len(clusters[0][0]) < max_width:
                 new_clusters = dict()
-                for (cluster_l, value_l), (cluster_r, value_r) in combinations(clusters, 2):
+                for (cluster_l, value_l), (cluster_r, value_r) in combinations(
+                    clusters, 2
+                ):
                     if len(cluster_l.union(cluster_r)) == len(cluster_l) + 1:
                         u = cluster_l.union(cluster_r)
                         new_clusters[u] = new_clusters.get(u, []) + [(value_l, value_r)]
@@ -46,22 +66,21 @@ def forward(self, batch, max_width=5):
         return self.output(torch.stack(result))
 
     def merge(self, pairs):
-        return sum(self.fold(self._pair_merge(x,y)) for x, y in pairs)
+        return sum(self.fold(self._pair_merge(x, y)) for x, y in pairs)
 
-    def _pair_merge(self, x,y):
-        beta = self.softmax(torch.stack([self.left(x),self.right(y)]))
-        h2 = beta[0]*self.w_l(x) + beta[1]*self.w_r(y)
+    def _pair_merge(self, x, y):
+        beta = self.softmax(torch.stack([self.left(x), self.right(y)]))
+        h2 = beta[0] * self.w_l(x) + beta[1] * self.w_r(y)
         return self.ff_rep(h2) + h2
 
     def fold(self, h):
-        return exp(self.attention_weight(h))*h
+        return exp(self.attention_weight(h)) * h
 
     def top_level_merge(self, clusters):
-        t = torch.stack([c for (_,c) in clusters])
+        t = torch.stack([c for (_, c) in clusters])
         sm = self.softmax(self.top_level_attention_weight(t))
-        return torch.sum(t*sm, dim=0)
-
+        return torch.sum(t * sm, dim=0)
 
 
 def graphyk(graph: nx.Graph):
-    graph.nodes()
+    graph.nodes()

chem/models/lstm.py (+7 −2)

@@ -5,7 +5,7 @@
 import sys
 from chem.models.base import JCIBaseNet
 
-logging.getLogger('pysmiles').setLevel(logging.CRITICAL)
+logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
 
 class ChemLSTM(JCIBaseNet):
@@ -15,7 +15,12 @@ def __init__(self, in_d, out_d, num_classes, **kwargs):
         super().__init__(num_classes, **kwargs)
         self.lstm = nn.LSTM(in_d, out_d, batch_first=True)
         self.embedding = nn.Embedding(800, 100)
-        self.output = nn.Sequential(nn.Linear(out_d, in_d), nn.ReLU(), nn.Dropout(0.2), nn.Linear(in_d, num_classes))
+        self.output = nn.Sequential(
+            nn.Linear(out_d, in_d),
+            nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(in_d, num_classes),
+        )
 
     def forward(self, data):
         x = data.x

chem/models/recursive.py (+16 −7)

@@ -6,7 +6,7 @@
 from chem.models.base import JCIBaseNet
 
 
-logging.getLogger('pysmiles').setLevel(logging.CRITICAL)
+logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
 
 class Recursive(JCIBaseNet):
@@ -40,7 +40,12 @@ def __init__(self, in_d, out_d, num_classes, **kwargs):
 
         self.base = torch.nn.parameter.Parameter(torch.empty((in_d,)))
         self.base_memory = torch.nn.parameter.Parameter(torch.empty((mem_len,)))
-        self.output = nn.Sequential(nn.Linear(in_d, in_d), nn.ReLU(), nn.Dropout(0.2), nn.Linear(in_d, num_classes))
+        self.output = nn.Sequential(
+            nn.Linear(in_d, in_d),
+            nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(in_d, num_classes),
+        )
 
     def forward(self, batch):
         result = []
@@ -49,7 +54,9 @@ def forward(self, batch):
             c = nx.center(graph)[0]
             d = nx.single_source_shortest_path(graph, c)
             if graph.edges:
-                digraph = nx.DiGraph((a,b) if d[a] > d[b] else (b,a) for (a,b) in graph.edges)
+                digraph = nx.DiGraph(
+                    (a, b) if d[a] > d[b] else (b, a) for (a, b) in graph.edges
+                )
             else:
                 digraph = nx.DiGraph(graph)
             child_results = {}
@@ -68,19 +75,21 @@ def forward(self, batch):
         return torch.stack(result)
 
     def merge_childen(self, child_values, x):
-        stack = torch.stack(child_values).unsqueeze(0).transpose(1,0)
-        att = self.children_attention(x.expand(1, stack.shape[1], -1).transpose(1, 0), stack, stack)[0]
+        stack = torch.stack(child_values).unsqueeze(0).transpose(1, 0)
+        att = self.children_attention(
+            x.expand(1, stack.shape[1], -1).transpose(1, 0), stack, stack
+        )[0]
         return torch.sum(att.squeeze(0), dim=0)
 
     def input(self, x0, hidden):
 
         x = x0.unsqueeze(0).unsqueeze(0)
-        a = self.input_norm_1(x + self.input_attention(x,x,x)[0])
+        a = self.input_norm_1(x + self.input_attention(x, x, x)[0])
         a = self.input_norm_2(a + F.relu(self.input_post(a)))
 
         h0 = hidden.unsqueeze(0).unsqueeze(0)
         b = self.hidden_norm_1(h0 + self.input_attention(h0, h0, h0)[0])
-        #b = self.norm(b + self.hidden_post(b))
+        # b = self.norm(b + self.hidden_post(b))
 
         c = self.merge_norm_1(b + self.merge_attention(a, b, b)[0])
         c = self.merge_norm_2(c + F.relu(self.merge_post(c)))

chem/molecule.py (+165 −60)

@@ -33,22 +33,28 @@ def __init__(self, smile, logp=None, contract_rings=False):
 
         for i in range(self.no_of_atoms):
             atom = m.GetAtomWithIdx(i)
-            self.graph.add_node(i, attr_dict={"atom_features": Molecule.atom_features(atom)})
+            self.graph.add_node(
+                i, attr_dict={"atom_features": Molecule.atom_features(atom)}
+            )
            for neighbour in atom.GetNeighbors():
                 neighbour_idx = neighbour.GetIdx()
                 bond = m.GetBondBetweenAtoms(i, neighbour_idx)
-                self.graph.add_edge(i, neighbour_idx,
-                                    attr_dict={"bond_features": Molecule.bond_features(bond)})
+                self.graph.add_edge(
+                    i,
+                    neighbour_idx,
+                    attr_dict={"bond_features": Molecule.bond_features(bond)},
+                )
 
         self.create_directed_graphs()
-        #self.create_feature_vectors()
+        # self.create_feature_vectors()
 
     def create_directed_graphs(self):
-        '''
+        """
         :return:
-        '''
+        """
         self.directed_graphs = np.empty(
-            (self.no_of_atoms, self.no_of_atoms - 1, 3), dtype=int)
+            (self.no_of_atoms, self.no_of_atoms - 1, 3), dtype=int
+        )
 
         self.dag_to_node = {}
 
@@ -66,32 +72,34 @@ def create_directed_graphs(self):
                 break
 
     def create_feature_vectors(self):
-        '''
+        """
         :return:
-        '''
+        """
         # create a three dimesnional matrix I,
         # such that Iij is the local input vector for jth vertex in ith DAG
 
         length_of_bond_features = Molecule.num_bond_features()
         length_of_atom_features = Molecule.num_atom_features()
 
         self.local_input_vector = np.zeros(
-            (self.no_of_atoms, self.no_of_atoms, Molecule.num_of_features()))
-
+            (self.no_of_atoms, self.no_of_atoms, Molecule.num_of_features())
+        )
 
         for idx in range(self.no_of_atoms):
             sorted_path = self.directed_graphs[idx, :, :]
 
-            self.local_input_vector[idx, idx, :length_of_atom_features] = \
-                self.get_atom_features(idx)
+            self.local_input_vector[
+                idx, idx, :length_of_atom_features
+            ] = self.get_atom_features(idx)
 
             no_of_incoming_edges = {}
             for i in range(self.no_of_atoms - 1):
                 node1 = sorted_path[i, 0]
                 node2 = sorted_path[i, 1]
 
-                self.local_input_vector[idx, node1, :length_of_atom_features] = \
-                    self.get_atom_features(node1)
+                self.local_input_vector[
+                    idx, node1, :length_of_atom_features
+                ] = self.get_atom_features(node1)
 
                 if node2 in no_of_incoming_edges:
                     index = no_of_incoming_edges[node2]
@@ -102,12 +110,12 @@ def create_feature_vectors(self):
                     index = 0
                     no_of_incoming_edges[node2] = 1
 
-
-                start = length_of_atom_features + index* length_of_bond_features
+                start = length_of_atom_features + index * length_of_bond_features
                 end = start + length_of_bond_features
 
-                self.local_input_vector[idx, node2, start:end] = \
-                    self.get_bond_features(node1, node2)
+                self.local_input_vector[idx, node2, start:end] = self.get_bond_features(
+                    node1, node2
+                )
 
     def get_cycle(self):
         try:
@@ -116,7 +124,12 @@ def get_cycle(self):
             return []
 
     def collect_atom_features(self):
-        self.af = {node_id: torch.tensor(self.graph.nodes[node_id]["attr_dict"]["atom_features"]).float() for node_id in range(self.no_of_atoms)}
+        self.af = {
+            node_id: torch.tensor(
+                self.graph.nodes[node_id]["attr_dict"]["atom_features"]
+            ).float()
+            for node_id in range(self.no_of_atoms)
+        }
 
     def get_atom_features(self, node_id):
         return self.af[node_id]
@@ -127,37 +140,121 @@ def get_bond_features(self, node1, node2):
 
     @staticmethod
     def atom_features(atom):
-        return np.array(Molecule.one_of_k_encoding_unk(atom.GetSymbol(),
-                                                       ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl',
-                                                        'Br', 'Mg', 'Na',
-                                                        'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K',
-                                                        'Tl', 'Yb',
-                                                        'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti',
-                                                        'Zn', 'H',  # H?
-                                                        'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In',
-                                                        'Mn',
-                                                        'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown']) +
-                        Molecule.one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5]) +
-                        Molecule.one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) +
-                        Molecule.one_of_k_encoding_unk(atom.GetImplicitValence(),
-                                                       [0, 1, 2, 3, 4, 5]) + [atom.GetIsAromatic()])
+        return np.array(
+            Molecule.one_of_k_encoding_unk(
+                atom.GetSymbol(),
+                [
+                    "C",
+                    "N",
+                    "O",
+                    "S",
+                    "F",
+                    "Si",
+                    "P",
+                    "Cl",
+                    "Br",
+                    "Mg",
+                    "Na",
+                    "Ca",
+                    "Fe",
+                    "As",
+                    "Al",
+                    "I",
+                    "B",
+                    "V",
+                    "K",
+                    "Tl",
+                    "Yb",
+                    "Sb",
+                    "Sn",
+                    "Ag",
+                    "Pd",
+                    "Co",
+                    "Se",
+                    "Ti",
+                    "Zn",
+                    "H",  # H?
+                    "Li",
+                    "Ge",
+                    "Cu",
+                    "Au",
+                    "Ni",
+                    "Cd",
+                    "In",
+                    "Mn",
+                    "Zr",
+                    "Cr",
+                    "Pt",
+                    "Hg",
+                    "Pb",
+                    "Unknown",
+                ],
+            )
+            + Molecule.one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
+            + Molecule.one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
+            + Molecule.one_of_k_encoding_unk(
+                atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5]
+            )
+            + [atom.GetIsAromatic()]
+        )
 
     @staticmethod
     def atom_features_of_contract_rings(degree):
-        return np.array(Molecule.one_of_k_encoding_unk('Unknown',
-                                                       ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl',
-                                                        'Br', 'Mg', 'Na',
-                                                        'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K',
-                                                        'Tl', 'Yb',
-                                                        'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti',
-                                                        'Zn', 'H',  # H?
-                                                        'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In',
-                                                        'Mn', 'Zr',
-                                                        'Cr', 'Pt', 'Hg', 'Pb', 'Unknown']) +
-                        Molecule.one_of_k_encoding(degree, [0, 1, 2, 3, 4, 5]) +
-                        Molecule.one_of_k_encoding_unk(0, [0, 1, 2, 3, 4]) +
-                        Molecule.one_of_k_encoding_unk(0, [0, 1, 2, 3, 4, 5]) +
-                        [0])
+        return np.array(
+            Molecule.one_of_k_encoding_unk(
+                "Unknown",
+                [
+                    "C",
+                    "N",
+                    "O",
+                    "S",
+                    "F",
+                    "Si",
+                    "P",
+                    "Cl",
+                    "Br",
+                    "Mg",
+                    "Na",
+                    "Ca",
+                    "Fe",
+                    "As",
+                    "Al",
+                    "I",
+                    "B",
+                    "V",
+                    "K",
+                    "Tl",
+                    "Yb",
+                    "Sb",
+                    "Sn",
+                    "Ag",
+                    "Pd",
+                    "Co",
+                    "Se",
+                    "Ti",
+                    "Zn",
+                    "H",  # H?
+                    "Li",
+                    "Ge",
+                    "Cu",
+                    "Au",
+                    "Ni",
+                    "Cd",
+                    "In",
+                    "Mn",
+                    "Zr",
+                    "Cr",
+                    "Pt",
+                    "Hg",
+                    "Pb",
+                    "Unknown",
+                ],
+            )
+            + Molecule.one_of_k_encoding(degree, [0, 1, 2, 3, 4, 5])
+            + Molecule.one_of_k_encoding_unk(0, [0, 1, 2, 3, 4])
+            + Molecule.one_of_k_encoding_unk(0, [0, 1, 2, 3, 4, 5])
+            + [0]
+        )
 
     @staticmethod
     def bond_features_between_contract_rings():
@@ -166,22 +263,30 @@ def bond_features_between_contract_rings():
     @staticmethod
     def bond_features(bond):
         bt = bond.GetBondType()
-        return np.array([bt == Chem.rdchem.BondType.SINGLE,
-                         bt == Chem.rdchem.BondType.DOUBLE,
-                         bt == Chem.rdchem.BondType.TRIPLE,
-                         bt == Chem.rdchem.BondType.AROMATIC,
-                         bond.GetIsConjugated(),
-                         bond.IsInRing()])
+        return np.array(
+            [
+                bt == Chem.rdchem.BondType.SINGLE,
+                bt == Chem.rdchem.BondType.DOUBLE,
+                bt == Chem.rdchem.BondType.TRIPLE,
+                bt == Chem.rdchem.BondType.AROMATIC,
+                bond.GetIsConjugated(),
+                bond.IsInRing(),
+            ]
+        )
 
     @staticmethod
     def num_of_features():
-        return Molecule.max_number_of_parents*Molecule.num_bond_features() + Molecule.num_atom_features()
+        return (
+            Molecule.max_number_of_parents * Molecule.num_bond_features()
+            + Molecule.num_atom_features()
+        )
 
     @staticmethod
     def one_of_k_encoding(x, allowable_set):
         if x not in allowable_set:
             raise Exception(
-                "input {0} not in allowable set{1}:".format(x, allowable_set))
+                "input {0} not in allowable set{1}:".format(x, allowable_set)
+            )
         return list(map(lambda s: x == s, allowable_set))
 
     @staticmethod
@@ -194,20 +299,20 @@ def one_of_k_encoding_unk(x, allowable_set):
     @staticmethod
     def num_atom_features():
         # Return length of feature vector using a very simple molecule.
-        m = Chem.MolFromSmiles('CC')
+        m = Chem.MolFromSmiles("CC")
         alist = m.GetAtoms()
         a = alist[0]
         return len(Molecule.atom_features(a))
 
     @staticmethod
     def num_bond_features():
         # Return length of feature vector using a very simple molecule.
-        simple_mol = Chem.MolFromSmiles('CC')
+        simple_mol = Chem.MolFromSmiles("CC")
         Chem.SanitizeMol(simple_mol)
         return len(Molecule.bond_features(simple_mol.GetBonds()[0]))
 
 
-if __name__ == '__main__':
-    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+if __name__ == "__main__":
+    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
     logging.basicConfig(level=logging.INFO, format=log_format)
     logger = logging.getLogger(__name__)

chem/run.py (+23 −12)

@@ -2,23 +2,34 @@
 from chem.data import datasets as ds
 import sys
 
+
 def main(batch_size):
     exps = [
-        (electra.ElectraPre,
-         dict(
-             lr=1e-4,
-             config=dict(
-                 vocab_size=1400,
-                 max_position_embeddings=1800,
-                 num_attention_heads=8,
-                 num_hidden_layers=6,
-                 type_vocab_size=1)),
-         (ds.PubChemFullToken,)),
+        (
+            electra.ElectraPre,
+            dict(
+                lr=1e-4,
+                config=dict(
+                    vocab_size=1400,
+                    max_position_embeddings=1800,
+                    num_attention_heads=8,
+                    num_hidden_layers=6,
+                    type_vocab_size=1,
+                ),
+            ),
+            (ds.PubChemFullToken,),
+        ),
     ]
     for net_cls, model_kwargs, datasets in exps:
        for dataset in datasets:
             for weighted in [False]:
-                net_cls.run(dataset(batch_size), net_cls.NAME, model_kwargs=model_kwargs, weighted=weighted)
+                net_cls.run(
+                    dataset(batch_size),
+                    net_cls.NAME,
+                    model_kwargs=model_kwargs,
+                    weighted=weighted,
+                )
+
 
 if __name__ == "__main__":
-    main(int(sys.argv[1]))
+    main(int(sys.argv[1]))

chem/train.py (+148 −72)

@@ -20,50 +20,54 @@
 NUM_EPOCHS = 100
 LEARNING_RATE = 0.01
 
+
 def eval_model(model, dataset, test_labels):
-  raw_values = []
-  predictions = []
-  final_scores = []
+    raw_values = []
+    predictions = []
+    final_scores = []
 
-  with torch.no_grad():
-    for batch in dataset:
-      for molecule, label in batch:
-        model_outputs = model(molecule)
-        prediction = [1.0 if i > 0.5 else 0.0 for i in model_outputs]
-        predictions.append(prediction)
-        raw_values.append(model_outputs)
-        final_scores.append(f1_score(prediction, label.tolist()))
+    with torch.no_grad():
+        for batch in dataset:
+            for molecule, label in batch:
+                model_outputs = model(molecule)
+                prediction = [1.0 if i > 0.5 else 0.0 for i in model_outputs]
+                predictions.append(prediction)
+                raw_values.append(model_outputs)
+                final_scores.append(f1_score(prediction, label.tolist()))
 
-  avg_f1 = sum(final_scores) / len(final_scores)
-  return raw_values, predictions, final_scores, avg_f1
+    avg_f1 = sum(final_scores) / len(final_scores)
+    return raw_values, predictions, final_scores, avg_f1
 
 
 def crawl_info(DAG, sink_parents):
-  topological_order = [int(i[0]) for i in DAG]
-  target_nodes = [int(i[1]) for i in DAG]
-  sink = target_nodes[-1]
-  sources = []
-  parents = {}
+    topological_order = [int(i[0]) for i in DAG]
+    target_nodes = [int(i[1]) for i in DAG]
+    sink = target_nodes[-1]
+    sources = []
+    parents = {}
+
+    for i in range(len(topological_order)):
+        for j in range(len(target_nodes)):
+            if topological_order[i] == target_nodes[j]:
+                if topological_order[i] not in parents.keys():
+                    parents[topological_order[i]] = []
+                parents[topological_order[i]].append(topological_order[j])
 
-  for i in range(len(topological_order)):
-    for j in range(len(target_nodes)):
-      if topological_order[i] == target_nodes[j]:
-        if topological_order[i] not in parents.keys():
-          parents[topological_order[i]] = []
-        parents[topological_order[i]].append(topological_order[j])
+    for node in topological_order:
+        if node not in parents.keys():
+            sources.append(node)
 
-  for node in topological_order:
-    if node not in parents.keys():
-      sources.append(node)
+    return topological_order, sources, parents, sink, sink_parents
 
-  return topological_order, sources, parents, sink, sink_parents
 
 import random
 
+
 def collate(batch):
     input, labels = zip(*batch)
     return input, torch.stack(labels)
 
+
 def _execute(model, loss_fn, optimizer, data, device, with_grad=True):
     train_running_loss = 0
     data_size = 0
@@ -77,37 +81,58 @@ def _execute(model, loss_fn, optimizer, data, device, with_grad=True):
         prediction = model(molecules)
         loss = loss_fn(prediction, labels)
         data_size += 1
-        f1 += f1_score(prediction > 0.5, labels > 0.5, average='micro')
+        f1 += f1_score(prediction > 0.5, labels > 0.5, average="micro")
        train_running_loss += loss.item()
         if with_grad:
             print(f"Batch {num_batch}/{num_batches}")
             loss.backward()
             optimizer.step()
-    return train_running_loss/data_size, f1/data_size
+    return train_running_loss / data_size, f1 / data_size
+
 
-def execute_network(model, loss_fn, optimizer, train_data, validation_data, epochs, device):
+def execute_network(
+    model, loss_fn, optimizer, train_data, validation_data, epochs, device
+):
     model.to(device)
     model.device = device
 
     for name, param in model.named_parameters():
         if param.requires_grad:
            print(name)
 
-    columns_name=['epoch', 'train_running_loss', 'train_running_f1', 'eval_running_loss', 'eval_running_f1']
-    with open(r'../loss_f1_training_validation.csv', 'w') as f:
+    columns_name = [
+        "epoch",
+        "train_running_loss",
+        "train_running_f1",
+        "eval_running_loss",
+        "eval_running_f1",
+    ]
+    with open(r"../loss_f1_training_validation.csv", "w") as f:
        writer = csv.writer(f)
         writer.writerow(columns_name)
 
     for epoch in range(epochs):
-        train_running_loss, train_running_f1 = _execute(model, loss_fn, optimizer, train_data, device, with_grad=True)
+        train_running_loss, train_running_f1 = _execute(
+            model, loss_fn, optimizer, train_data, device, with_grad=True
+        )
 
         with torch.no_grad():
-            eval_running_loss, eval_running_f1 = _execute(model, loss_fn, optimizer, validation_data, device, with_grad=False)
+            eval_running_loss, eval_running_f1 = _execute(
+                model, loss_fn, optimizer, validation_data, device, with_grad=False
+            )
         print(
-            f'Epoch {epoch}: loss={train_running_loss:.5f}, f1={train_running_f1:.5f}, val_loss={eval_running_loss:.5f}, val_f1={eval_running_f1:.5f}'.format(
-                epoch, train_running_f1))
-        fields=[epoch, train_running_loss, train_running_f1, eval_running_loss, eval_running_f1]
-        with open(r'../loss_f1_training_validation.csv', 'a') as f:
+            f"Epoch {epoch}: loss={train_running_loss:.5f}, f1={train_running_f1:.5f}, val_loss={eval_running_loss:.5f}, val_f1={eval_running_f1:.5f}".format(
+                epoch, train_running_f1
+            )
+        )
+        fields = [
+            epoch,
+            train_running_loss,
+            train_running_f1,
+            eval_running_loss,
+            eval_running_f1,
+        ]
+        with open(r"../loss_f1_training_validation.csv", "a") as f:
             writer = csv.writer(f)
             writer.writerow(fields)
 
@@ -120,75 +145,97 @@ def prepare_data(infile):
     data_frame.reset_index(drop=True, inplace=True)
 
     data_classes = list(data_frame.columns)
-    data_classes.remove('MOLECULEID')
-    data_classes.remove('SMILES')
+    data_classes.remove("MOLECULEID")
+    data_classes.remove("SMILES")
 
     for col in data_classes:
         data_frame[col] = data_frame[col].astype(int)
 
     train_data = []
     for index, row in data_frame.iterrows():
-        train_data.append([
-            data_frame.iloc[index].values[1],
-            data_frame.iloc[index].values[2:502].tolist()
-        ])
-
-    train_df = pd.DataFrame(train_data, columns=['SMILES', 'LABELS'])
+        train_data.append(
+            [
+                data_frame.iloc[index].values[1],
+                data_frame.iloc[index].values[2:502].tolist(),
+            ]
+        )
+
+    train_df = pd.DataFrame(train_data, columns=["SMILES", "LABELS"])
     return train_df
 
 
 def batchify(x, y):
-    data = list(zip(x,y))
-    return [data[i*BATCH_SIZE:(i+1)*BATCH_SIZE] for i in range(1 + len(data)//BATCH_SIZE)]
+    data = list(zip(x, y))
+    return [
+        data[i * BATCH_SIZE : (i + 1) * BATCH_SIZE]
+        for i in range(1 + len(data) // BATCH_SIZE)
+    ]
+
 
 def load_data():
     fpath = "data/full.pickle"
     if os.path.isfile(fpath):
         with open(fpath, "rb") as f:
-            train_dataset, train_actual_labels, validation_dataset, validation_actual_labels = pickle.load(f)
+            (
+                train_dataset,
+                train_actual_labels,
+                validation_dataset,
+                validation_actual_labels,
+            ) = pickle.load(f)
     else:
-        print('reading data from files!')
-        train_infile = open('../data/JCI_graph/raw/train.pkl', 'rb')
-        test_infile = open('../data/JCI_graph/raw/test.pkl', 'rb')
-        validation_infile = open('../data/JCI_graph/raw/validation.pkl', 'rb')
+        print("reading data from files!")
+        train_infile = open("../data/JCI_graph/raw/train.pkl", "rb")
+        test_infile = open("../data/JCI_graph/raw/test.pkl", "rb")
+        validation_infile = open("../data/JCI_graph/raw/validation.pkl", "rb")
 
-        #test_data = prepare_data(test_infile)
+        # test_data = prepare_data(test_infile)
 
-        print('prepare train data!')
+        print("prepare train data!")
         train_dataset = []
         train_actual_labels = []
 
         for index, row in prepare_data(train_infile).iterrows():
             try:
-                mol = Molecule(row['SMILES'], True)
+                mol = Molecule(row["SMILES"], True)
 
                 DAGs_meta_info = mol.dag_to_node
                 train_dataset.append(mol)
-                train_actual_labels.append(torch.tensor(row['LABELS']).float())
+                train_actual_labels.append(torch.tensor(row["LABELS"]).float())
             except:
                 pass
 
-
-        print('prepare validation data!')
+        print("prepare validation data!")
         validation_dataset = []
         validation_actual_labels = []
 
-
         for index, row in prepare_data(validation_infile).iterrows():
             try:
-                mol = Molecule(row['SMILES'], True)
+                mol = Molecule(row["SMILES"], True)
 
                 DAGs_meta_info = mol.dag_to_node
 
                 validation_dataset.append(mol)
-                validation_actual_labels.append(torch.tensor(row['LABELS']).float())
+                validation_actual_labels.append(torch.tensor(row["LABELS"]).float())
             except:
-               pass
+                pass
 
         with open(fpath, "wb") as f:
-            pickle.dump((train_dataset, train_actual_labels, validation_dataset, validation_actual_labels), f)
-
-    return train_dataset, train_actual_labels, validation_dataset, validation_actual_labels
+            pickle.dump(
+                (
+                    train_dataset,
+                    train_actual_labels,
+                    validation_dataset,
+                    validation_actual_labels,
+                ),
+                f,
+            )
+
+    return (
+        train_dataset,
+        train_actual_labels,
+        validation_dataset,
+        validation_actual_labels,
+    )
 
 
 def move_molecule(m):
@@ -204,14 +251,43 @@ def move_molecule(m):
     accelerator = None
     trainer_kwargs = dict()
 
-    train_dataset, train_actual_labels, validation_dataset, validation_actual_labels = load_data()
-    train_data = data.DataLoader(list(zip(map(move_molecule, train_dataset), [l.float() for l in train_actual_labels])), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)
-    validation_data = data.DataLoader(list(zip(map(move_molecule, validation_dataset), [l.float() for l in validation_actual_labels])), batch_size=BATCH_SIZE, collate_fn=collate)
+    (
+        train_dataset,
+        train_actual_labels,
+        validation_dataset,
+        validation_actual_labels,
+    ) = load_data()
+    train_data = data.DataLoader(
+        list(
+            zip(
+                map(move_molecule, train_dataset),
+                [l.float() for l in train_actual_labels],
+            )
+        ),
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        collate_fn=collate,
+    )
+    validation_data = data.DataLoader(
+        list(
+            zip(
+                map(move_molecule, validation_dataset),
+                [l.float() for l in validation_actual_labels],
+            )
+        ),
+        batch_size=BATCH_SIZE,
+        collate_fn=collate,
+    )
 
     model = ChEBIRecNN()
 
-    tb_logger = pl_loggers.CSVLogger('../logs/')
-    trainer = pl.Trainer(logger=tb_logger, accelerator=accelerator, max_epochs=NUM_EPOCHS, **trainer_kwargs)
+    tb_logger = pl_loggers.CSVLogger("../logs/")
+    trainer = pl.Trainer(
+        logger=tb_logger,
+        accelerator=accelerator,
+        max_epochs=NUM_EPOCHS,
+        **trainer_kwargs,
+    )
     trainer.fit(model, train_data, val_dataloaders=validation_data)
 
 """
