
Commit 1826aa0

Fix electra pretraining

1 parent a2b051c · commit 1826aa0

File tree

5 files changed: +45 −55 lines

chebai/loss/pretraining.py (+17)

@@ -0,0 +1,17 @@
+import torch
+
+class ElectraPreLoss(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ce = torch.nn.CrossEntropyLoss()
+
+    def forward(self, input, target, **loss_kwargs):
+        t, p = input
+        gen_pred, disc_pred = t
+        gen_tar, disc_tar = p
+        gen_loss = self.ce(target=torch.argmax(gen_tar.int(), dim=-1), input=gen_pred)
+        disc_loss = self.ce(
+            target=torch.argmax(disc_tar.int(), dim=-1), input=disc_pred
+        )
+        return gen_loss + disc_loss
+
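The relocated loss takes the model output as a nested tuple ((gen_pred, disc_pred), (gen_tar, disc_tar)); the target argument and any extra loss kwargs are ignored. A minimal sketch of a call, with purely illustrative tensor shapes (the real tensors come from ElectraPre.forward):

# Minimal sketch; the shapes and counts below are made up for illustration.
import torch

from chebai.loss.pretraining import ElectraPreLoss

loss_fn = ElectraPreLoss()

n, vocab = 8, 1400                                   # hypothetical masked positions / vocab size
gen_pred = torch.randn(n, vocab)                     # generator logits
gen_tar = torch.nn.functional.one_hot(
    torch.randint(0, vocab, (n,)), vocab
).float()                                            # argmax(gen_tar) recovers the token indices
disc_pred = torch.randn(n, 2)                        # discriminator logits (original vs. replaced)
disc_tar = torch.nn.functional.one_hot(
    torch.randint(0, 2, (n,)), 2
).float()

# input is ((gen_pred, disc_pred), (gen_tar, disc_tar)); target is unused.
loss = loss_fn(((gen_pred, disc_pred), (gen_tar, disc_tar)), target=None)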

chebai/models/base.py (+4 −1)

@@ -41,10 +41,13 @@ def __init_subclass__(cls, **kwargs):
     def _get_prediction_and_labels(self, data, labels, output):
         return output, labels

+    def _process_labels_in_batch(self, batch):
+        return batch.y.float()
+
     def _process_batch(self, batch, batch_idx):
         return dict(
             features=batch.x,
-            labels=batch.y.float(),
+            labels=self._process_labels_in_batch(batch),
             model_kwargs=batch.additional_fields["model_kwargs"],
             loss_kwargs=batch.additional_fields["loss_kwargs"],
             idents=batch.additional_fields["idents"],
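The new _process_labels_in_batch hook is what lets ElectraPre opt out of label handling without re-implementing _process_batch. A simplified, hypothetical stand-in (not the project's actual base class) showing the override pattern:

# Simplified stand-in for the base model class, for illustration only.
class _BaseSketch:
    def _process_labels_in_batch(self, batch):
        return batch.y.float()

    def _process_batch(self, batch, batch_idx):
        return dict(
            features=batch.x,
            labels=self._process_labels_in_batch(batch),
        )


# A pretraining subclass overrides only the hook, mirroring ElectraPre in this commit.
class _PretrainingSketch(_BaseSketch):
    def _process_labels_in_batch(self, batch):
        return None  # pretraining batches carry no external labels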

chebai/models/electra.py (+4 −23)

@@ -47,15 +47,15 @@ def __init__(self, config=None, **kwargs):
     def as_pretrained(self):
         return self.discriminator

-    def _process_batch(self, batch, batch_idx):
-        return dict(features=batch.x, labels=None, mask=batch.mask)
+    def _process_labels_in_batch(self, batch):
+        return None

-    def forward(self, data):
+    def forward(self, data, **kwargs):
         features = data["features"]
         self.batch_size = batch_size = features.shape[0]
         max_seq_len = features.shape[1]

-        mask = data["mask"]
+        mask = kwargs["mask"]
         with torch.no_grad():
             dis_tar = (
                 torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1)

@@ -96,25 +96,6 @@ def forward(self, data):
     def _get_prediction_and_labels(self, batch, labels, output):
         return torch.softmax(output[0][1], dim=-1), output[1][1].int()

-    def _get_data_for_loss(self, model_output, labels):
-        return dict(input=model_output, target=None)
-
-
-class ElectraPreLoss(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.ce = torch.nn.CrossEntropyLoss()
-
-    def forward(self, input, target):
-        t, p = input
-        gen_pred, disc_pred = t
-        gen_tar, disc_tar = p
-        gen_loss = self.ce(target=torch.argmax(gen_tar.int(), dim=-1), input=gen_pred)
-        disc_loss = self.ce(
-            target=torch.argmax(disc_tar.int(), dim=-1), input=disc_pred
-        )
-        return gen_loss + disc_loss
-

 def filter_dict(d, filter_key):
     return {
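After this change, ElectraPre.forward receives the attention mask via **kwargs rather than from a custom _process_batch, and _get_prediction_and_labels exposes only the discriminator head to the metrics. A toy illustration of how the nested output tuple is read, using made-up tensors:

# Toy tensors standing in for ((gen_pred, disc_pred), (gen_tar, disc_tar)).
import torch

gen_pred, disc_pred = torch.randn(4, 1400), torch.randn(4, 2)
gen_tar, disc_tar = torch.rand(4, 1400), torch.rand(4, 2)
output = ((gen_pred, disc_pred), (gen_tar, disc_tar))

# Metrics only see the discriminator part (cf. _get_prediction_and_labels):
preds = torch.softmax(output[0][1], dim=-1)   # softmax over original-vs-replaced logits
labels = output[1][1].int()                   # discriminator targets as integers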
+20

@@ -0,0 +1,20 @@
+class_path: chebai.models.ElectraPre
+init_args:
+  criterion:
+    class_path: chebai.loss.pretraining.ElectraPreLoss
+  out_dim: null
+  optimizer_kwargs:
+    lr: 1e-4
+  config:
+    generator:
+      vocab_size: 1400
+      max_position_embeddings: 1800
+      num_attention_heads: 8
+      num_hidden_layers: 6
+      type_vocab_size: 1
+    discriminator:
+      vocab_size: 1400
+      max_position_embeddings: 1800
+      num_attention_heads: 8
+      num_hidden_layers: 6
+      type_vocab_size: 1
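The criterion entry uses the usual class_path / init_args convention, so the loss now resolves to chebai.loss.pretraining.ElectraPreLoss instead of the copy that used to live in electra.py. A rough sketch of how such a block is typically instantiated (jsonargparse-style; simplified, not the project's actual loader):

# Rough, simplified sketch of class_path/init_args resolution (not the project's loader).
import importlib


def instantiate(spec: dict):
    module_name, cls_name = spec["class_path"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    init_args = dict(spec.get("init_args", {}))
    for key, value in init_args.items():
        # nested class_path blocks (like the criterion above) are built recursively
        if isinstance(value, dict) and "class_path" in value:
            init_args[key] = instantiate(value)
    return cls(**init_args)


criterion = instantiate({"class_path": "chebai.loss.pretraining.ElectraPreLoss"})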

configs/training/electra_pretraining.template.yaml (−31)

This file was deleted.
