import torch
import torch.nn as nn
import torch.nn.functional as F

class TdnnAffine(torch.nn.Module):
    """A TDNN affine component implemented with conv1d.
        y = splice(w * x, context) + b

    @input_dim: dims per frame <=> input channels of the conv
    @output_dim: number of layer nodes <=> output channels of the conv
    @context: a list of context offsets, e.g. [-2, 0, 2].
        If context is [0], TdnnAffine is equivalent to a linear layer.
    """
    def __init__(self, input_dim, output_dim, context=[0], bias=True, pad=True,
                 stride=1, groups=1, norm_w=False, norm_f=False):
        super(TdnnAffine, self).__init__()
        assert input_dim % groups == 0
        # Check that the context is sorted and has no duplicate values.
        for index in range(0, len(context) - 1):
            if context[index] >= context[index + 1]:
                raise ValueError("Context {} is invalid: offsets must be strictly increasing.".format(context))

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.context = context
        self.bool_bias = bias
        self.pad = pad
        self.groups = groups
        self.norm_w = norm_w
        self.norm_f = norm_f
        # Used to subsample frames with this factor.
        self.stride = stride

        self.left_context = context[0] if context[0] < 0 else 0
        self.right_context = context[-1] if context[-1] > 0 else 0
        self.tot_context = self.right_context - self.left_context + 1

        # sphereConv (norm_f with a multi-frame context) is not supported yet.
        if self.tot_context > 1 and self.norm_f:
            self.norm_f = False
            print("Warning: sphereConv is not supported yet, so norm_f is set to False.")

        kernel_size = (self.tot_context,)
        self.weight = torch.nn.Parameter(torch.randn(output_dim, input_dim // groups, *kernel_size))
        if self.bool_bias:
            self.bias = torch.nn.Parameter(torch.randn(output_dim))
        else:
            self.register_parameter('bias', None)

        # Init weight and bias. It is important.
        self.init_weight()

        # If the context is not contiguous, build a broadcastable 0/1 mask that
        # skips the unused frame offsets; broadcasting saves GPU memory.
        if len(context) != self.tot_context:
            self.mask = torch.tensor([[[1 if index in context else 0
                                        for index in range(self.left_context, self.right_context + 1)]]])
        else:
            self.mask = None

        ## Deprecated: a full-size mask costs more GPU memory than the broadcast one above.
        # self.mask = torch.randn(output_dim, input_dim, 0)
        # for index in range(self.left_context, self.right_context + 1):
        #     if index in context:
        #         fixed_value = torch.ones(output_dim, input_dim, 1)
        #     else:
        #         fixed_value = torch.zeros(output_dim, input_dim, 1)
        #     self.mask = torch.cat((self.mask, fixed_value), dim=2)

        self.selected_device = False

    def init_weight(self):
        # Note: the variance should be small to avoid slow shrinking.
        torch.nn.init.normal_(self.weight, 0., 0.01)
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0.)

    def forward(self, inputs):
        """
        @inputs: a 3-dimensional tensor (a batch), i.e. [samples-index, frames-dim-index, frames-index]
        """
        assert len(inputs.shape) == 3
        assert inputs.shape[1] == self.input_dim

        # Do not use conv1d's own padding, which is symmetric, because
        # self.left_context + self.right_context may be != 0.
        if self.pad:
            inputs = F.pad(inputs, (-self.left_context, self.right_context), mode="constant", value=0)

        assert inputs.shape[2] >= self.tot_context

        if not self.selected_device and self.mask is not None:
            # Move the mask to the input device once, to save CPU -> GPU moving time.
            # For a temporary tensor, just specify the device when creating it,
            # e.g. this_tensor = torch.tensor([1.0], device=inputs.device).
            self.mask = self.mask.to(inputs.device)
            self.selected_device = True

        filters = self.weight * self.mask if self.mask is not None else self.weight

        if self.norm_w:
            filters = F.normalize(filters, dim=1)

        if self.norm_f:
            inputs = F.normalize(inputs, dim=1)

        outputs = F.conv1d(inputs, filters, self.bias, stride=self.stride, padding=0, dilation=1, groups=self.groups)

        return outputs

    def extra_repr(self):
        return '{input_dim}, {output_dim}, context={context}, bias={bool_bias}, stride={stride}, ' \
               'pad={pad}, groups={groups}, norm_w={norm_w}, norm_f={norm_f}'.format(**self.__dict__)

    @classmethod
    def thop_count(cls, m, x, y):
        # FLOPs counting hook (e.g. for the thop profiler).
        x = x[0]
        kernel_ops = torch.zeros(m.weight.size()[2:]).numel()  # Kw (conv1d kernel width)
        bias_ops = 1 if m.bias is not None else 0
        # N x Cout x T x (Cin x Kw + bias)
        total_ops = y.nelement() * (m.input_dim * kernel_ops + bias_ops)
        m.total_ops += torch.DoubleTensor([int(total_ops)])
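
# A minimal usage sketch of TdnnAffine (an illustration added here, not part of
# the original component; the shapes and values are assumptions). It is kept in
# a helper function so nothing runs at import time.
def _demo_tdnn_affine():
    frames = torch.randn(4, 40, 100)                  # [batch, feature-dim, frames]
    # Splice offsets -2, 0, +2; pad=True keeps the frame-axis length unchanged.
    tdnn = TdnnAffine(40, 512, context=[-2, 0, 2])
    assert tdnn(frames).shape == (4, 512, 100)
    # With context=[0], TdnnAffine reduces to a per-frame linear layer.
    linear_like = TdnnAffine(40, 512, context=[0])
    assert linear_like(frames).shape == (4, 512, 100)
    # For FLOPs counting with thop, thop_count can be registered as a custom op, e.g.
    # thop.profile(model, inputs=(frames,), custom_ops={TdnnAffine: TdnnAffine.thop_count}).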

# Attention-based
class AttentionAlphaComponent(torch.nn.Module):
    """Compute the alpha with an attention module.
        alpha = softmax(v'·f(w·x + b) + k) or softmax(v'·x + k)
    where f is relu here and the bias terms are optional.

    Supports:
    1. Single-head or multi-head attention
    2. One affine layer or two affine layers
    3. Shared weight (last affine = vector) or un-shared weight (last affine = matrix)
    4. Self-attention or time-context attention (via the context parameter of TdnnAffine)
    5. Different temperatures for different heads
    """
    def __init__(self, input_dim, num_head=1, split_input=True, share=True, affine_layers=2,
                 hidden_size=64, context=[0], bias=True, temperature=False, fixed=True):
        super(AttentionAlphaComponent, self).__init__()
        assert num_head >= 1
        # Multi-head case.
        if num_head > 1:
            if split_input:
                # Make sure features/planes with input_dim dims can be split into num_head parts.
                assert input_dim % num_head == 0
            if temperature:
                if fixed:
                    t_list = []
                    for i in range(num_head):
                        t_list.append([[max(1, (i // 2) * 5)]])
                    # shape [1, num_head, 1, 1]
                    self.register_buffer('t', torch.tensor([t_list]))
                else:
                    # Different heads have different temperatures.
                    # Use 1 + self.t**2 in forward to make sure temperature >= 1.
                    self.t = torch.nn.Parameter(torch.zeros(1, num_head, 1, 1))

        self.input_dim = input_dim
        self.num_head = num_head
        self.split_input = split_input
        self.share = share
        self.temperature = temperature
        self.fixed = fixed

        if share:
            # weight: [input_dim, 1] or [input_dim, hidden_size] -> [hidden_size, 1]
            final_dim = 1
        elif split_input:
            # weight: [input_dim, input_dim // num_head] or [input_dim, hidden_size] -> [hidden_size, input_dim // num_head]
            final_dim = input_dim // num_head
        else:
            # weight: [input_dim, input_dim] or [input_dim, hidden_size] -> [hidden_size, input_dim]
            final_dim = input_dim

        first_groups = 1
        last_groups = 1

        if affine_layers == 1:
            last_affine_input_dim = input_dim
            # (x, 1) for the global case and (x, h) for the split case.
            if num_head > 1 and split_input:
                last_groups = num_head
            self.relu_affine = False
        elif affine_layers == 2:
            last_affine_input_dim = hidden_size * num_head
            if num_head > 1:
                # (1, h) for the global case and (h, h) for the split case.
                last_groups = num_head
                if split_input:
                    first_groups = num_head
            # Add a relu-affine with affine_layers=2.
            self.relu_affine = True
            self.first_affine = TdnnAffine(input_dim, last_affine_input_dim, context=context, bias=bias, groups=first_groups)
            # self.first_affine = torch.nn.Linear(input_dim, last_affine_input_dim)
            self.relu = torch.nn.ReLU(inplace=True)
        else:
            raise ValueError("Expected 1 or 2 affine layers, but got {}.".format(affine_layers))

        self.last_affine = TdnnAffine(last_affine_input_dim, final_dim * num_head, context=context, bias=bias, groups=last_groups)
        # self.last_affine = torch.nn.Linear(last_affine_input_dim, final_dim * num_head)
        # dim=2 applies the softmax over the frames index (the batch is a 3-dim tensor here).
        self.softmax = torch.nn.Softmax(dim=2)

    def forward(self, inputs):
        """
        @inputs: a 3-dimensional tensor (a batch), i.e. [samples-index, frames-dim-index, frames-index]
        """
        assert len(inputs.shape) == 3
        assert inputs.shape[1] == self.input_dim

        if self.temperature:
            batch_size = inputs.shape[0]
            chunk_size = inputs.shape[2]

        x = inputs
        if self.relu_affine:
            x = self.relu(self.first_affine(x))

        if self.num_head > 1 and self.temperature:
            if self.fixed:
                t = self.t
            else:
                t = 1 + self.t**2
            x = self.last_affine(x).reshape(batch_size, self.num_head, -1, chunk_size) / t
            return self.softmax(x.reshape(batch_size, -1, chunk_size))
        else:
            return self.softmax(self.last_affine(x))
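
# A minimal usage sketch of AttentionAlphaComponent (an illustration added here;
# the shapes are assumptions). The returned alpha keeps the frame axis and sums
# to 1 over frames, so it can weight a statistics-pooling layer.
def _demo_attention_alpha():
    frames = torch.randn(4, 512, 100)    # [batch, feature-dim, frames]
    # Single-head, shared-weight attention with two affine layers.
    attention = AttentionAlphaComponent(512, num_head=1, share=True, affine_layers=2)
    alpha = attention(frames)            # [4, 1, 100]
    assert torch.allclose(alpha.sum(dim=2), torch.ones(4, 1))
    # Attention-weighted mean over frames (the usual pooling use of alpha).
    mean = torch.sum(alpha * frames, dim=2)   # [4, 512]
    assert mean.shape == (4, 512)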