#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros.h>
+ #include <ATen/ops/cat.h>
+ #include <ATen/ops/cumsum.h>
#endif

namespace at {
@@ -68,7 +70,7 @@ Tensor _get_complete_sum(const Tensor& lengths) {
  offsets[0].zero_();

  AT_DISPATCH_INDEX_TYPES(
-     lengths.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+     lengths.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
        auto* lengths_data_ptr = lengths.data_ptr<index_t>();
        auto* offsets_data_ptr = offsets.data_ptr<index_t>();
        at::cuda::cub::inclusive_sum(
@@ -108,22 +110,34 @@ __global__ void segment_reduce_forward_kernel(
    const index_t* lengths_data,
    const index_t* lengths_cumsum_data,
    const int64_t segment_count,
-   const int64_t stride_count,
+   const int64_t lengths_stride_axis,
    bool is_initial_set,
-   scalar_t initial_value) {
+   scalar_t initial_value,
+   const int64_t outer_offset,
+   const int64_t inner_offset,
+   const int64_t data_stride_axis,
+   const int64_t data_size_axis,
+   const int64_t output_stride_axis,
+   const int64_t output_size_axis,
+   const int64_t lengths_cumsum_stride_axis) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
- int64_t row_id = idx / stride_count;
- int64_t lane_id = idx % stride_count;
- if (idx >= (segment_count * stride_count)) {
+ if (idx >= (outer_offset * segment_count * inner_offset)) {
    return;
  }
- int64_t offset_start = lengths_cumsum_data[row_id];
- int64_t offset_end = lengths_cumsum_data[row_id + 1];
+ int64_t row_id = idx / inner_offset;
+ int64_t lane_id = idx % inner_offset; // lane_id is the inner_idx
+ int64_t outer_idx = row_id / segment_count;
+ int64_t dim_idx = row_id % segment_count;
+
+ int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
+ index_t offset_start = lengths_cumsum_data[offset_idx];
+ index_t offset_end = lengths_cumsum_data[offset_idx + 1];

  // ===== step2: apply reduction
- for (int64_t j = offset_start; j < offset_end; ++j) {
-   int64_t starting_index = (j * stride_count) + lane_id;
-   const auto data = values_data[starting_index];
+ for (index_t j = offset_start; j < offset_end; ++j) {
+   int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+       + j * data_stride_axis + lane_id;
+   const auto data = values_data[data_index];
    // TODO: There is no need to branch with every element
    if (reduction == SegmentReductionType::MAX) {
      initial_value =
@@ -142,19 +156,22 @@ __global__ void segment_reduce_forward_kernel(
  }

  // ===== step3: finalize reduction
- CUDA_KERNEL_ASSERT(lengths_data[row_id] >= 0);
- if (lengths_data[row_id] == 0 && !is_initial_set &&
+ int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
+ CUDA_KERNEL_ASSERT(lengths_data[lengths_idx] >= 0);
+ if (lengths_data[lengths_idx] == 0 && !is_initial_set &&
      reduction == SegmentReductionType::MEAN) {
    initial_value = static_cast<scalar_t>(NAN);
  } else if (
-     reduction == SegmentReductionType::MEAN && lengths_data[row_id] > 0 &&
+     reduction == SegmentReductionType::MEAN && lengths_data[lengths_idx] > 0 &&
      !at::_isnan(initial_value)) {
-   initial_value = initial_value / lengths_data[row_id];
+   initial_value = initial_value / lengths_data[lengths_idx];
  }
- int64_t output_index = (row_id * stride_count) + lane_id;
+ int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+     + dim_idx * output_stride_axis + lane_id;
  output_data[output_index] = initial_value;
}

+
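
For reference, a minimal host-side sketch of the flat-index decomposition the two rewritten kernels share (illustration only, not part of the patch; SegmentIndices and decompose are made-up names, and the example values assume contiguous tensors):

#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the kernels' index arithmetic.
struct SegmentIndices {
  int64_t outer_idx, dim_idx, lane_id;
  int64_t lengths_idx, offset_idx, output_index;
};

SegmentIndices decompose(
    int64_t idx,                        // blockIdx.x * blockDim.x + threadIdx.x
    int64_t segment_count,              // lengths.size(axis)
    int64_t inner_offset,               // product of output sizes after axis
    int64_t lengths_stride_axis,
    int64_t lengths_cumsum_stride_axis,
    int64_t output_stride_axis,
    int64_t output_size_axis) {
  SegmentIndices s;
  int64_t row_id = idx / inner_offset;
  s.lane_id = idx % inner_offset;        // position within the inner dims
  s.outer_idx = row_id / segment_count;  // which outer row
  s.dim_idx = row_id % segment_count;    // which segment along the reduced axis
  // lengths holds segment_count entries per outer row; the cumulative-sum
  // tensor holds segment_count + 1 because of the prepended zero column.
  s.lengths_idx = s.outer_idx * lengths_stride_axis * segment_count + s.dim_idx;
  s.offset_idx =
      s.outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + s.dim_idx;
  s.output_index = s.outer_idx * output_stride_axis * output_size_axis +
      s.dim_idx * output_stride_axis + s.lane_id;
  return s;
}

int main() {
  // Output of shape (2, 5, 3) reduced along axis 1, everything contiguous:
  // segment_count = 5, inner_offset = 3; lengths and offsets have stride 1
  // along the axis, output has stride 3 and size 5 along it.
  SegmentIndices s = decompose(/*idx=*/22, 5, 3, 1, 1, 3, 5);
  std::printf("outer=%lld dim=%lld lane=%lld out=%lld\n",
              (long long)s.outer_idx, (long long)s.dim_idx,
              (long long)s.lane_id, (long long)s.output_index);  // 1 2 1 22
  return 0;
}
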
template <typename scalar_t, typename index_t>
__global__ void segment_reduce_backward_kernel(
    SegmentReductionType reduction,
@@ -165,32 +182,46 @@ __global__ void segment_reduce_backward_kernel(
    const index_t* lengths_data,
    const index_t* lengths_cumsum_data,
    const int64_t segment_count,
-   const int64_t stride_count,
-   scalar_t initial_prod_value) {
+   const int64_t lengths_stride_axis,
+   scalar_t initial_prod_value,
+   const int64_t outer_offset,
+   const int64_t inner_offset,
+   const int64_t data_stride_axis,
+   const int64_t data_size_axis,
+   const int64_t output_stride_axis,
+   const int64_t output_size_axis,
+   const int64_t lengths_cumsum_stride_axis) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
- int64_t row_id = idx / stride_count;
- int64_t lane_id = idx % stride_count;
-
- if (idx >= (segment_count * stride_count)) {
+ if (idx >= (outer_offset * segment_count * inner_offset)) {
    return;
  }
- if (lengths_data[row_id] == 0) {
+ int64_t row_id = idx / inner_offset;
+ int64_t lane_id = idx % inner_offset; // lane_id is the inner_idx
+ int64_t outer_idx = row_id / segment_count;
+ int64_t dim_idx = row_id % segment_count;
+
+ int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
+ auto segment_length = lengths_data[lengths_idx];
+ if (segment_length == 0) {
    return;
  }

- int64_t offset_start = lengths_cumsum_data[row_id];
- int64_t offset_end = lengths_cumsum_data[row_id + 1];
+ int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
+ index_t offset_start = lengths_cumsum_data[offset_idx];
+ index_t offset_end = lengths_cumsum_data[offset_idx + 1];

- int64_t output_index = (row_id * stride_count) + lane_id;
+ int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+     + dim_idx * output_stride_axis + lane_id;

  if (reduction == SegmentReductionType::MAX ||
      reduction == SegmentReductionType::MIN) {
    int64_t counter = 0;
    for (int64_t j = offset_start; j < offset_end; ++j) {
-     int64_t starting_index = (j * stride_count) + lane_id;
-     if (at::_isnan(values_data[starting_index]) ||
-         values_data[starting_index] == output_data[output_index]) {
-       grad_input_data[starting_index] = grad_data[output_index];
+     int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+         + j * data_stride_axis + lane_id;
+     if (at::_isnan(values_data[data_index]) ||
+         values_data[data_index] == output_data[output_index]) {
+       grad_input_data[data_index] = grad_data[output_index];
        counter++;
      }
    }
@@ -200,47 +231,51 @@ __global__ void segment_reduce_backward_kernel(
      return;
    }
    for (int64_t j = offset_start; j < offset_end; ++j) {
-     int64_t starting_index = (j * stride_count) + lane_id;
-     if (grad_input_data[starting_index] > 0) {
-       grad_input_data[starting_index] =
-           grad_input_data[starting_index] / counter;
+     int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+         + j * data_stride_axis + lane_id;
+     if (grad_input_data[data_index] > 0) {
+       grad_input_data[data_index] =
+           grad_input_data[data_index] / counter;
      }
    }
  } else if (reduction == SegmentReductionType::MEAN) {
-   auto grad_val = grad_data[output_index] / lengths_data[row_id];
+   auto grad_val = grad_data[output_index] / segment_length;
    for (int64_t j = offset_start; j < offset_end; ++j) {
-     int64_t starting_index = (j * stride_count) + lane_id;
-     grad_input_data[starting_index] = grad_val;
+     int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+         + j * data_stride_axis + lane_id;
+     grad_input_data[data_index] = grad_val;
    }
  } else if (reduction == SegmentReductionType::SUM) {
    const auto& grad_val = grad_data[output_index];
    for (int64_t j = offset_start; j < offset_end; ++j) {
-     int64_t starting_index = (j * stride_count) + lane_id;
-     grad_input_data[starting_index] = grad_val;
+     int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+         + j * data_stride_axis + lane_id;
+     grad_input_data[data_index] = grad_val;
    }
  } else if (reduction == SegmentReductionType::PROD) {
    const auto& grad_val = grad_data[output_index] * output_data[output_index];
    for (int64_t j = offset_start; j < offset_end; ++j) {
-     int64_t starting_index = (j * stride_count) + lane_id;
-     if (at::_isnan(values_data[starting_index]) ||
-         values_data[starting_index] == 0) {
+     int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+         + j * data_stride_axis + lane_id;
+     if (at::_isnan(values_data[data_index]) ||
+         values_data[data_index] == 0) {
        // explicitly compute exclusive prod
        scalar_t exclusive_prod = initial_prod_value;
-       int64_t idx;
+       int64_t prod_idx;
        for (int64_t k = offset_start; k < offset_end; ++k) {
          if (k != j) {
-           idx = (k * stride_count) + lane_id;
-           exclusive_prod *= values_data[idx];
+           prod_idx = outer_idx * data_stride_axis * data_size_axis
+               + k * data_stride_axis + lane_id;
+           exclusive_prod *= values_data[prod_idx];
          }
        }
-       grad_input_data[starting_index] = grad_data[output_index] * exclusive_prod;
+       grad_input_data[data_index] = grad_data[output_index] * exclusive_prod;
      } else {
-       grad_input_data[starting_index] = grad_val / values_data[starting_index];
+       grad_input_data[data_index] = grad_val / values_data[data_index];
      }
    }
  }
}
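
To make the PROD branch above concrete, a small standalone illustration of the gradient it computes (plain C++, not part of the patch):

#include <cstdio>

int main() {
  // One segment with values {2, 3, 4}: the forward PROD output is 24.
  double values[3] = {2.0, 3.0, 4.0};
  double output = values[0] * values[1] * values[2];
  double grad = 1.0;                // upstream gradient for this segment
  double grad_val = grad * output;  // grad_data[output_index] * output_data[output_index]
  // Common case in the kernel: dividing by the element itself yields the
  // exclusive product 2 * 4 = 8. When the element is 0 or NaN this division
  // is invalid, which is why the kernel falls back to the explicit loop.
  std::printf("d(prod)/d(values[1]) = %g\n", grad_val / values[1]);
  return 0;
}
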
-
} // namespace

Tensor _segment_reduce_cuda_backward_kernel(
@@ -251,28 +286,43 @@ Tensor _segment_reduce_cuda_backward_kernel(
    const Tensor& lengths_contig,
    int64_t axis,
    const c10::optional<Scalar>& initial) {
- int64_t segment_count = lengths_contig.numel();
- auto output_shape = data_contig.sizes().vec();
- output_shape[axis] = segment_count;
+ axis = lengths_contig.dim() - 1;
+ int64_t segment_count = lengths_contig.size(axis);
+ int64_t lengths_stride_axis = lengths_contig.stride(axis);
  auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());

- int64_t stride_count = data_contig.numel() / data_contig.size(axis);
+ auto zeros_shape = lengths_contig.sizes().vec();
+ zeros_shape[axis] = 1;
+ auto offsets = at::cat({at::zeros(zeros_shape, lengths_contig.options()), lengths_contig}, axis);
+ offsets.cumsum_(axis);

- auto offsets = _get_complete_sum(lengths_contig);
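
For concreteness (illustration only): with axis = 1 and a small 2x2 lengths tensor, the cat/cumsum_ sequence above builds per-row offsets with segment_count + 1 entries per outer row, which is what the kernels index as outer_idx * (segment_count + 1) + dim_idx when the last axis is contiguous:

lengths        = [[1, 2], [3, 0]]
after at::cat  = [[0, 1, 2], [0, 3, 0]]   (zero column prepended along axis)
after cumsum_  = [[0, 1, 3], [0, 3, 3]]   (row-wise running sums)
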
+ // outer_offset is the size of the outer dimensions of output (before axis)
+ // inner_offset is the size of the inner dimensions of output (after axis)
+ int64_t outer_offset = 1, inner_offset = 1;
+ for (int64_t d = 0; d < axis; d++) {
+   outer_offset *= output_contig.size(d);
+ }
+ for (int64_t d = axis + 1; d < output_contig.dim(); d++) {
+   inner_offset *= output_contig.size(d);
+ }
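
For example, an output_contig of shape (4, 6, 2) with axis = 1 gives outer_offset = 4 and inner_offset = 2, so the launch below covers outer_offset * inner_offset * segment_count = 48 threads, one per output element.
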

  constexpr int threads_per_block = 256;
- int64_t num_blocks =
-     ((segment_count * stride_count) + threads_per_block - 1) /
-     threads_per_block;
+ int64_t num_blocks = (outer_offset * inner_offset * segment_count + threads_per_block - 1) / threads_per_block;

  num_blocks = std::max(num_blocks, (int64_t)1);

+ auto data_stride_axis = data_contig.stride(axis);
+ auto data_size_axis = data_contig.size(axis);
+ auto output_stride_axis = output_contig.stride(axis);
+ auto output_size_axis = output_contig.size(axis);
+ auto offsets_stride_axis = offsets.stride(axis);
+
  AT_DISPATCH_INDEX_TYPES(
-     lengths_contig.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+     lengths_contig.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
        const auto* lengths_data = lengths_contig.data_ptr<index_t>();
        auto* offsets_data = offsets.data_ptr<index_t>();

-       // TODO: Swtich to TensorIterator for better maintainablility and
+       // TODO: Switch to TensorIterator for better maintainablility and
        // readability
        AT_DISPATCH_FLOATING_TYPES_AND2(
            kBFloat16,
@@ -305,8 +355,16 @@ Tensor _segment_reduce_cuda_backward_kernel(
                lengths_data,
                offsets_data,
                segment_count,
-               stride_count,
-               initial_prod_value);
+               lengths_stride_axis,
+               initial_prod_value,
+               outer_offset,
+               inner_offset,
+               data_stride_axis,
+               data_size_axis,
+               output_stride_axis,
+               output_size_axis,
+               offsets_stride_axis
+           );
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          }));
      }));
@@ -319,24 +377,46 @@ Tensor _segment_reduce_cuda_kernel(
    const Tensor& lengths,
    int64_t axis,
    const c10::optional<Scalar>& initial) {
- int64_t segment_count = lengths.numel();
+ // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
+ TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
+ TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous.");
+ axis = lengths.dim() - 1;
+ int64_t segment_count = lengths.size(axis);
+ int64_t lengths_stride_axis = lengths.stride(axis);
  auto output_shape = data.sizes().vec();
  output_shape[axis] = segment_count;
  auto output = at::empty(output_shape, data.options());

- int64_t stride_count = data.numel() / data.size(axis);
-
- auto offsets = _get_complete_sum(lengths);
+ // _get_complete_sum only supports 1D?
+ auto zeros_shape = lengths.sizes().vec();
+ zeros_shape[axis] = 1;
+ auto offsets = at::cat({at::zeros(zeros_shape, lengths.options()), lengths}, axis);
+ offsets.cumsum_(axis);
+
+ // outer_offset is the size of the outer dimensions of output (before axis)
+ // inner_offset is the size of the inner dimensions of output (after axis)
+ int64_t outer_offset = 1, inner_offset = 1;
+ for (int64_t d = 0; d < axis; d++) {
+   outer_offset *= output.size(d);
+ }
+ for (int64_t d = axis + 1; d < output.dim(); d++) {
+   inner_offset *= output.size(d);
+ }

  constexpr int threads_per_block = 256;
- int64_t num_blocks =
-     ((segment_count * stride_count) + threads_per_block - 1) /
-     threads_per_block;
+ // segment_count * stride_count is just output.numel() ?
+ int64_t num_blocks = (output.numel() + threads_per_block - 1) / threads_per_block;
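
Since output_shape[axis] is set to segment_count above, output.numel() is exactly outer_offset * segment_count * inner_offset, so this grid matches both the bound checked inside the kernel and the formula used on the backward path.
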

  num_blocks = std::max(num_blocks, (int64_t)1);

+ auto data_stride_axis = data.stride(axis);
+ auto data_size_axis = data.size(axis);
+ auto output_stride_axis = output.stride(axis);
+ auto output_size_axis = output.size(axis);
+ auto offsets_stride_axis = offsets.stride(axis);
+
  AT_DISPATCH_INDEX_TYPES(
-     lengths.type(), "_segment_reduce_cuda_kernel1", ([&] {
+     lengths.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
        auto* offsets_data_ptr = offsets.data_ptr<index_t>();
        auto* lengths_data_ptr = lengths.data_ptr<index_t>();
        AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -376,9 +456,17 @@ Tensor _segment_reduce_cuda_kernel(
                lengths_data_ptr,
                offsets_data_ptr,
                segment_count,
-               stride_count,
+               lengths_stride_axis,
                initial.has_value(),
-               initial_value);
+               initial_value,
+               outer_offset,
+               inner_offset,
+               data_stride_axis,
+               data_size_axis,
+               output_stride_axis,
+               output_size_axis,
+               offsets_stride_axis
+           );
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          } else {
            if (reduction == SegmentReductionType::MAX) {