
Commit 4ba621d

Merge pull request #2 from ChEB-AI/dev
pull changes into fork
2 parents 3833a6e + 0ccbc27 commit 4ba621d

6 files changed, +48 -58 lines changed

README.md

+3-3
@@ -6,20 +6,20 @@ ChEBai is a deep learning library that allows the combination of deep learning
 ## Pretraining
 
 ```
-python -m chebai fit --config=[path-to-your-config] --trainer.callbacks=configs/training/default_callbacks.yml
+python -m chebai fit --data.class_path=chebai.preprocessing.datasets.pubchem.SWJChem --model=configs/model/electra-for-pretraining.ElectraPre --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --trainer=configs/training/default_trainer.yml --trainer.callbacks=configs/training/default_callbacks.yml
 ```
 
 ## Structure-based ontology extension
 
 ```
-python -m chebai train --config=[path-to-your-electra_chebi100-config] --trainer.callbacks=configs/training/default_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model] --model.load_prefix=generator.
+python -m chebai fit --config=[path-to-your-electra_chebi100-config] --trainer.callbacks=configs/training/default_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model] --model.load_prefix=generator.
 ```
 
 
 ## Fine-tuning for Toxicity prediction
 
 ```
-python -m chebai train --config=[path-to-your-tox21-config] --trainer.callbacks=configs/training/default_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model] --model.load_prefix=generator.
+python -m chebai fit --config=[path-to-your-tox21-config] --trainer.callbacks=configs/training/default_callbacks.yml --model.pretrained_checkpoint=[path-to-pretrained-model] --model.load_prefix=generator.
 ```
 
 ```

chebai/loss/pretraining.py

+17
@@ -0,0 +1,17 @@
+import torch
+
+class ElectraPreLoss(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ce = torch.nn.CrossEntropyLoss()
+
+    def forward(self, input, target, **loss_kwargs):
+        t, p = input
+        gen_pred, disc_pred = t
+        gen_tar, disc_tar = p
+        gen_loss = self.ce(target=torch.argmax(gen_tar.int(), dim=-1), input=gen_pred)
+        disc_loss = self.ce(
+            target=torch.argmax(disc_tar.int(), dim=-1), input=disc_pred
+        )
+        return gen_loss + disc_loss
+
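
The nested tuple that this `forward` unpacks is easiest to see with dummy tensors. A minimal sketch, assuming 2-D logits and one-hot targets; the vocabulary size of 1400 is borrowed from the model config added further down in this commit, and the binary discriminator head is an assumption for illustration:

```
import torch
from chebai.loss.pretraining import ElectraPreLoss

# Dummy shapes: 8 token positions, generator vocabulary of 1400, binary discriminator head.
gen_pred = torch.randn(8, 1400)   # generator logits
disc_pred = torch.randn(8, 2)     # discriminator logits (original vs. replaced token)
gen_tar = torch.nn.functional.one_hot(torch.randint(0, 1400, (8,)), num_classes=1400)
disc_tar = torch.nn.functional.one_hot(torch.randint(0, 2, (8,)), num_classes=2)

loss_fn = ElectraPreLoss()
# `input` is a pair of pairs: (predictions, targets), each split into (generator, discriminator).
# The separate `target` argument is unused; the targets travel inside `input`.
loss = loss_fn(input=((gen_pred, disc_pred), (gen_tar, disc_tar)), target=None)
print(loss)  # scalar: generator cross-entropy + discriminator cross-entropy
```

The indexing in `_get_prediction_and_labels` of `chebai/models/electra.py` below (`output[0][1]`, `output[1][1]`) suggests that `ElectraPre.forward` returns this same `((gen_pred, disc_pred), (gen_tar, disc_tar))` structure.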

chebai/models/base.py

+4-1
@@ -41,10 +41,13 @@ def __init_subclass__(cls, **kwargs):
     def _get_prediction_and_labels(self, data, labels, output):
         return output, labels
 
+    def _process_labels_in_batch(self, batch):
+        return batch.y.float()
+
     def _process_batch(self, batch, batch_idx):
         return dict(
             features=batch.x,
-            labels=batch.y.float(),
+            labels=self._process_labels_in_batch(batch),
             model_kwargs=batch.additional_fields["model_kwargs"],
             loss_kwargs=batch.additional_fields["loss_kwargs"],
             idents=batch.additional_fields["idents"],
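
The point of the new `_process_labels_in_batch` hook is that a subclass can change only how labels are read from a batch while inheriting the rest of `_process_batch`; the `ElectraPre` change in the next file uses it to drop labels entirely during pretraining. A hypothetical sketch of another override (the base class name `ChebaiBaseNet` is an assumption here, and `IntegerLabelModel` does not exist in the repository):

```
from chebai.models.base import ChebaiBaseNet  # assumed name of the base class patched above


class IntegerLabelModel(ChebaiBaseNet):
    """Toy subclass: read labels as integer class indices instead of floats."""

    def _process_labels_in_batch(self, batch):
        # Only the label handling changes; features, model_kwargs, loss_kwargs and
        # idents are still assembled by the inherited _process_batch.
        return batch.y.long()
```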

chebai/models/electra.py

+4-23
@@ -47,15 +47,15 @@ def __init__(self, config=None, **kwargs):
     def as_pretrained(self):
         return self.discriminator
 
-    def _process_batch(self, batch, batch_idx):
-        return dict(features=batch.x, labels=None, mask=batch.mask)
+    def _process_labels_in_batch(self, batch):
+        return None
 
-    def forward(self, data):
+    def forward(self, data, **kwargs):
         features = data["features"]
         self.batch_size = batch_size = features.shape[0]
         max_seq_len = features.shape[1]
 
-        mask = data["mask"]
+        mask = kwargs["mask"]
         with torch.no_grad():
             dis_tar = (
                 torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1)
@@ -96,25 +96,6 @@ def forward(self, data):
     def _get_prediction_and_labels(self, batch, labels, output):
         return torch.softmax(output[0][1], dim=-1), output[1][1].int()
 
-    def _get_data_for_loss(self, model_output, labels):
-        return dict(input=model_output, target=None)
-
-
-class ElectraPreLoss(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.ce = torch.nn.CrossEntropyLoss()
-
-    def forward(self, input, target):
-        t, p = input
-        gen_pred, disc_pred = t
-        gen_tar, disc_tar = p
-        gen_loss = self.ce(target=torch.argmax(gen_tar.int(), dim=-1), input=gen_pred)
-        disc_loss = self.ce(
-            target=torch.argmax(disc_tar.int(), dim=-1), input=disc_pred
-        )
-        return gen_loss + disc_loss
-
 
 def filter_dict(d, filter_key):
     return {
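
The visible fragment of the pretraining `forward` draws, for every sequence in the batch, a position among its non-padding tokens (a uniform random number scaled by the mask sum, presumably cast to an integer index afterwards; the rest of the expression is cut off at the hunk boundary). A self-contained sketch of just that sampling step, with made-up shapes:

```
import torch

batch_size, max_seq_len = 4, 10
# Boolean attention mask: True for real tokens, False for padding.
lengths = torch.tensor([10, 7, 5, 3])
mask = torch.arange(max_seq_len).unsqueeze(0) < lengths.unsqueeze(1)

with torch.no_grad():
    # For each sequence, draw a position uniformly from [0, number of real tokens).
    # The .long() cast here stands in for whatever the real implementation applies.
    dis_tar = (torch.rand((batch_size,)) * torch.sum(mask, dim=-1)).long()

print(dis_tar)  # one sampled token position per sequence
```
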
+20
@@ -0,0 +1,20 @@
+class_path: chebai.models.ElectraPre
+init_args:
+  criterion:
+    class_path: chebai.loss.pretraining.ElectraPreLoss
+  out_dim: null
+  optimizer_kwargs:
+    lr: 1e-4
+  config:
+    generator:
+      vocab_size: 1400
+      max_position_embeddings: 1800
+      num_attention_heads: 8
+      num_hidden_layers: 6
+      type_vocab_size: 1
+    discriminator:
+      vocab_size: 1400
+      max_position_embeddings: 1800
+      num_attention_heads: 8
+      num_hidden_layers: 6
+      type_vocab_size: 1
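
Read through the usual `class_path` / `init_args` convention, the new model config above corresponds roughly to the following instantiation (normally the `python -m chebai fit` command builds this object from the YAML). This is a sketch only: the exact `ElectraPre` constructor signature is not shown in this diff, and the keyword names are taken directly from the config:

```
from chebai.loss.pretraining import ElectraPreLoss
from chebai.models import ElectraPre  # class_path given in the config

electra_params = dict(
    vocab_size=1400,
    max_position_embeddings=1800,
    num_attention_heads=8,
    num_hidden_layers=6,
    type_vocab_size=1,
)

# Rough Python equivalent of the YAML above (sketch; built by the CLI in practice).
model = ElectraPre(
    criterion=ElectraPreLoss(),
    out_dim=None,
    optimizer_kwargs={"lr": 1e-4},
    config={"generator": dict(electra_params), "discriminator": dict(electra_params)},
)
```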

configs/training/electra_pretraining.template.yaml

-31
This file was deleted.
