bigbird.patch
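Speed up BigBird's block-sparse random-attention mask generation on the fixed plans (from_seq_len in 1024/3072/4096): BigBirdBlockSparseAttention now precomputes, once per attention module, the candidate random-block permutations for every query block (generate_rand_attn_tables) and samples one cached permutation per head and row at mask-generation time (_bigbird_block_rand_mask_fast), replacing the per-call np.random.permutation work done through _bigbird_block_rand_mask.
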
diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py
index 2a7e86aa8..47fba01ba 100755
--- a/src/transformers/models/big_bird/modeling_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_big_bird.py
@@ -27,6 +27,8 @@ from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from itertools import permutations
+
 from ...activations import ACT2FN
 from ...file_utils import (
     ModelOutput,
@@ -439,6 +441,52 @@ class BigBirdBlockSparseAttention(nn.Module):
         self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
         self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias)
 
+
+        self.rand_attn_tables = [[] for _ in range(config.max_position_embeddings // config.block_size - 1)]
+        self.generate_rand_attn_tables(self.max_seqlen, self.max_seqlen, self.block_size, self.block_size, self.num_random_blocks, 1024)
+        self.rand_attn_tables_prepared_arg = [self.max_seqlen, self.max_seqlen, self.block_size, self.block_size, self.num_random_blocks, 1024]
+
+    def generate_one_table(self, o_table, start_i, end_i, num_rand_blocks):
+        return list(permutations(o_table[start_i:end_i], num_rand_blocks))
+
+    def generate_rand_attn_tables(self, from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx):
+        all_tables = self.rand_attn_tables
+        middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32)
+        last = to_seq_length // to_block_size - 1
+        if last_idx > (2 * to_block_size):
+            last = (last_idx // to_block_size) - 1
+        for i in range(1, from_seq_length // from_block_size - 1):
+            start = i - 2
+            end = i
+            if i == 1:
+                # all_tables[i - 1] = np.random.permutation(middle_seq[2:last])[:r]
+                all_tables[i - 1] = self.generate_one_table(middle_seq, 2, last, num_rand_blocks)
+            elif i == 2:
+                # rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r]
+                all_tables[i - 1] = self.generate_one_table(middle_seq, 3, last, num_rand_blocks)
+            elif i == from_seq_length // from_block_size - 3:
+                # rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
+                all_tables[i - 1] = self.generate_one_table(middle_seq, 0, last, num_rand_blocks)
+                # Missing -3: should have been sliced till last-3
+            elif i == from_seq_length // from_block_size - 2:
+                # rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r]
+                all_tables[i - 1] = self.generate_one_table(middle_seq, 0, last, num_rand_blocks)
+                # Missing -4: should have been sliced till last-4
+            else:
+                if start > last:
+                    start = last
+                    # rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
+                    all_tables[i - 1] = self.generate_one_table(middle_seq, 0, start, num_rand_blocks)
+                elif (end + 1) == last:
+                    # rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r]
+                    all_tables[i - 1] = self.generate_one_table(middle_seq, 0, start, num_rand_blocks)
+                else:
+                    # rand_attn[i - 1, :] = np.random.permutation(
+                    #     np.concatenate((middle_seq[:start], middle_seq[end + 1 : last]))
+                    # )[:r]
+                    new_middle_seq = np.concatenate((middle_seq[:start], middle_seq[end + 1 : last]))
+                    all_tables[i - 1] = self.generate_one_table(new_middle_seq, 0, len(new_middle_seq), num_rand_blocks)
+
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
@@ -567,12 +615,7 @@
         # generate random attention and corresponding masks
         np.random.seed(seed)
         if from_seq_len in [1024, 3072, 4096]:  # old plans used in paper
-            rand_attn = [
-                self._bigbird_block_rand_mask(
-                    self.max_seqlen, self.max_seqlen, from_block_size, to_block_size, n_rand_blocks, last_idx=1024
-                )[: (from_seq_len // from_block_size - 2)]
-                for _ in range(n_heads)
-            ]
+            rand_attn = self._bigbird_block_rand_mask_fast(from_seq_len, self.max_seqlen, from_block_size, to_block_size, n_rand_blocks, 1024, n_heads)
         else:
             if plan_from_length is None:
                 plan_from_length, plan_num_rand_blocks = self._get_rand_attn_plan(
@@ -1056,6 +1099,16 @@
         return plan_from_length, plan_num_rand_blocks
 
+    def _bigbird_block_rand_mask_fast(
+        self, from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx, n_heads
+    ):
+        rand_attn = [np.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32) for _ in range(n_heads)]
+        for i in range(1, from_seq_length // from_block_size - 1):
+            rand_i = np.random.randint(len(self.rand_attn_tables[i - 1]), size=n_heads)
+            for j, rand_index in enumerate(rand_i):
+                rand_attn[j][i - 1] = self.rand_attn_tables[i - 1][rand_index]
+        return rand_attn
+
     @staticmethod
     def _bigbird_block_rand_mask(
         from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1
     ):
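--
How it works: generate_rand_attn_tables enumerates, once in __init__, every ordered choice of num_rand_blocks candidate blocks for each query-block row (the commented-out lines quote the original np.random.permutation logic each branch replaces), and _bigbird_block_rand_mask_fast then builds the per-head mask by drawing a single np.random.randint index into the cached table per row, vectorized across heads. Picking uniformly from list(permutations(candidates, r)) is distributionally equivalent to np.random.permutation(candidates)[:r], since the first r entries of a uniform permutation are a uniform r-permutation; the seeded RNG stream is consumed differently, though, so masks for a given seed will not match the unpatched code bit-for-bit.

Below is a minimal standalone sketch of the same caching idea, runnable outside transformers. The constants and the simplified candidate-set selection are illustrative assumptions (the last_idx capping and the short-sequence edge cases of the real code are omitted):

    # Sketch: precompute permutation tables once, then sample cheaply per head.
    from itertools import permutations

    import numpy as np

    BLOCK_SIZE, MAX_SEQLEN, NUM_RAND_BLOCKS, N_HEADS = 64, 1024, 3, 12  # assumed config
    num_blocks = MAX_SEQLEN // BLOCK_SIZE   # 16 blocks of 64 tokens
    middle = np.arange(1, num_blocks - 1)   # non-global candidate blocks 1..14

    # One-time: all ordered r-tuples of candidate blocks for each "from" block i.
    tables = []
    for i in range(1, num_blocks - 1):
        start, end = i - 2, i
        if i == 1:
            cand = middle[2:]
        elif i == 2:
            cand = middle[3:]
        elif i >= num_blocks - 3:            # mirrors the patch's last two special cases
            cand = middle
        else:                                # drop the sliding-window neighbours
            cand = np.concatenate((middle[:start], middle[end + 1 :]))
        tables.append(list(permutations(cand, NUM_RAND_BLOCKS)))

    # Per forward pass: one randint per (head, row) instead of a permutation per row.
    rand_attn = np.empty((N_HEADS, num_blocks - 2, NUM_RAND_BLOCKS), dtype=np.int32)
    for row, table in enumerate(tables):
        for head, k in enumerate(np.random.randint(len(table), size=N_HEADS)):
            rand_attn[head, row] = table[k]

    print(rand_attn.shape)  # (12, 14, 3)

Memory stays modest here because last_idx=1024 with 64-token blocks caps each row's candidate pool at 15 blocks, so a row's table holds at most 15*14*13 = 2,730 tuples. The enumeration grows factorially in num_rand_blocks, so the cache is only cheap while r stays small.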