@@ -39,26 +39,29 @@ def __init__(self, p=0.2, **kwargs):
39
39
self .replace_p = 0.1
40
40
41
41
def _get_data_and_labels (self , batch , batch_idx ):
42
- return dict (features = batch .x , labels = None )
42
+ return dict (features = batch .x , labels = None , lens = torch . tensor ( batch . lens ) )
43
43
44
44
def forward(self, data):
    """Run one ELECTRA-style generator/discriminator step.

    For each sequence in the batch, one position inside the valid
    (unpadded) region is zeroed out ("masked"); the generator predicts
    a replacement token for it and the discriminator scores the
    resulting sequence.

    Args:
        data: dict with
            "features": (batch, seq_len) token-id tensor,
            "lens": (batch,) tensor of valid sequence lengths.

    Returns:
        ``((gen_embeddings, disc_logits), (target_embeddings, replaced))``
        where ``replaced`` is 1.0 at positions where the generator's
        token differs from the masked input, 0.0 elsewhere.
    """
    x = torch.clone(data["features"])  # don't mutate the caller's batch
    self.batch_size = x.shape[0]

    gen_tar = []  # original token id at each masked position
    dis_tar = []  # masked position index per sequence
    lens = data["lens"]
    max_len = x.shape[1]
    # Boolean padding mask: True for real tokens, False for padding.
    # NOTE(review): built on the default device — assumes features/lens
    # are on the same device; confirm before GPU training.
    mask = torch.arange(max_len).expand(len(lens), max_len) < lens.unsqueeze(1)

    for i, l in enumerate(lens):
        # random.randint is inclusive on BOTH ends, so the upper bound
        # must be l - 1: j == l would mask the first padding slot, and
        # raise IndexError whenever l == max_len.
        j = random.randint(0, int(l) - 1)
        t = x[i, j].item()
        x[i, j] = 0
        gen_tar.append(t)
        dis_tar.append(j)

    # Generator prediction (kept outside no_grad so it receives gradients).
    gen_out = torch.max(torch.sum(self.generator(x, attention_mask=mask).logits, dim=1), dim=-1)[1]

    with torch.no_grad():
        xc = x.clone()
        for i in range(x.shape[0]):
            xc[i, dis_tar[i]] = gen_out[i]
        # True only where the generator's token differs from the masked
        # (zeroed) input at the sampled position.
        replaced_by_different = torch.ne(x, xc)
        disc_out = self.discriminator(xc, attention_mask=mask)

    return (
        (self.generator.electra.embeddings(gen_out.unsqueeze(-1)), disc_out.logits),
        (
            self.generator.electra.embeddings(
                torch.tensor(gen_tar, device=self.device).unsqueeze(-1)
            ),
            replaced_by_different.float(),
        ),
    )
def _get_prediction_and_labels (self , batch , labels , output ):
0 commit comments