
Commit 2f8f2aa: Fix bug in labeling
1 parent: bd9071f

3 files changed (+14 −11)

.pre-commit-config.yaml (+1 −1)

```diff
@@ -4,6 +4,6 @@ repos:
     hooks:
       - id: isort
   - repo: https://github.com/psf/black
-    rev: 21.9b0
+    rev: 22.10
     hooks:
       - id: black
```

chebai/models/base.py (+2 −2)

```diff
@@ -61,7 +61,7 @@ def _get_data_and_labels(self, batch, batch_idx):
     def _execute(self, batch, batch_idx):
         data = self._get_data_and_labels(batch, batch_idx)
         labels = data["labels"]
-        model_output = self(data["features"], **data.get("model_kwargs", dict()))
+        model_output = self(data, **data.get("model_kwargs", dict()))
         return data, labels, model_output

     def calculate_metrics(self, data, labels, model_output):
@@ -70,7 +70,7 @@ def calculate_metrics(self, data, labels, model_output):
         pred, labels = self._get_prediction_and_labels(data, labels,
                                                        model_output)
         f1 = self.f1(target=labels.int(), preds=torch.sigmoid(pred))
-        mse = self.mse(labels, torch.sigmoid(pred))
+        mse = self.mse(target=labels, preds=torch.sigmoid(pred))
         return loss, f1, mse

     def training_step(self, *args, **kwargs):
```
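Two things change here: `_execute` now hands the whole data dict to `forward` (so subclasses, such as the Electra model below, can read extra keys like `lens`), and the `mse` call now names its arguments. A minimal sketch of the metric call, assuming a recent torchmetrics; the tensors are hypothetical stand-ins:

```python
import torch
from torchmetrics import MeanSquaredError

# Hypothetical logits and binary targets for a batch of 4 samples, 10 labels.
pred = torch.randn(4, 10)
labels = torch.randint(0, 2, (4, 10)).float()

mse = MeanSquaredError()
# torchmetrics metrics are called as metric(preds, target); naming both
# arguments makes the roles explicit and matches the keyword-style f1 call.
value = mse(preds=torch.sigmoid(pred), target=labels)
```

Since squared error is symmetric in its two arguments, the old positional call was numerically equivalent; the keyword form just makes the intent unambiguous.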

chebai/models/electra.py (+11 −8)

```diff
@@ -39,26 +39,29 @@ def __init__(self, p=0.2, **kwargs):
         self.replace_p = 0.1

     def _get_data_and_labels(self, batch, batch_idx):
-        return dict(features=batch.x, labels=None)
+        return dict(features=batch.x, labels=None, lens=torch.tensor(batch.lens))

     def forward(self, data):
-        self.batch_size = data.shape[0]
-        x = torch.clone(data)
+        x = torch.clone(data["features"])
+        self.batch_size = x.shape[0]
         gen_tar = []
         dis_tar = []
-        for i in range(x.shape[0]):
-            j = random.randint(0, x.shape[1]-1)
-            t = x[i,j]
+        lens = data["lens"]
+        max_len = x.shape[1]
+        mask = torch.arange(max_len).expand(len(lens), max_len) < lens.unsqueeze(1)
+        for i, l in enumerate(lens):
+            j = random.randint(0, l)
+            t = x[i,j].item()
             x[i,j] = 0
             gen_tar.append(t)
             dis_tar.append(j)
-        gen_out = torch.max(torch.sum(self.generator(x).logits,dim=1), dim=-1)[1]
+        gen_out = torch.max(torch.sum(self.generator(x, attention_mask=mask).logits,dim=1), dim=-1)[1]
         with torch.no_grad():
             xc = x.clone()
             for i in range(x.shape[0]):
                 xc[i,dis_tar[i]] = gen_out[i]
             replaced_by_different = torch.ne(x, xc)
-        disc_out = self.discriminator(xc)
+        disc_out = self.discriminator(xc, attention_mask=mask)
         return (self.generator.electra.embeddings(gen_out.unsqueeze(-1)), disc_out.logits), (self.generator.electra.embeddings(torch.tensor(gen_tar, device=self.device).unsqueeze(-1)), replaced_by_different.float())

     def _get_prediction_and_labels(self, batch, labels, output):
```
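This is the labeling fix itself: instead of corrupting a random position anywhere in the padded tensor, the batch now carries per-sequence lengths, the corrupted position is sampled within each sequence's own length, and an attention mask keeps generator and discriminator from attending to padding. A minimal sketch of that mask expression with hypothetical lengths:

```python
import torch

# Hypothetical unpadded lengths for three sequences padded to max_len = 5.
lens = torch.tensor([3, 5, 2])
max_len = 5

# Same expression as in forward(): row i is True exactly where position < lens[i],
# i.e. on real tokens, and False on padding.
mask = torch.arange(max_len).expand(len(lens), max_len) < lens.unsqueeze(1)
print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True,  True, False, False, False]])
```

One detail worth noting: Python's `random.randint(0, l)` is inclusive at both ends, so `j == l` (the first padding slot) can still be drawn; `random.randint(0, l - 1)` would confine the corruption strictly to real tokens.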
