
Commit e727539

mikaylagawarecki authored and pytorchmergebot committed
Support multi-dimensional lengths in segment_reduce to support pytorch_scatter.segment_* functionalities (CUDA)
Pull Request resolved: pytorch#77061
Approved by: https://github.com/cpuhrsch
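What the change enables, sketched in Python: `lengths` may now be multi-dimensional, carrying one row of segment lengths per outer batch, mirroring `pytorch_scatter.segment_*`. This is an illustrative sketch, not code from the PR; the Python entry point and keyword names (`torch._segment_reduce`, `lengths=`, `axis=`) are assumptions and have varied across releases.

```python
import torch

# Hypothetical usage sketch: 2-D lengths, one row of segment lengths per batch.
data = torch.arange(24.0, device="cuda").reshape(2, 6, 2)
lengths = torch.tensor([[2, 4], [3, 3]], device="cuda")  # each row sums to data.size(1)
out = torch._segment_reduce(data, "sum", lengths=lengths, axis=1)
print(out.shape)  # expected: torch.Size([2, 2, 2])
```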
1 parent 38350ac commit e727539

File tree

4 files changed: +171 −81


aten/src/ATen/native/cuda/SegmentReduce.cu (+158 −70)
@@ -13,6 +13,8 @@
 #else
 #include <ATen/ops/empty.h>
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/cat.h>
+#include <ATen/ops/cumsum.h>
 #endif

 namespace at {
@@ -68,7 +70,7 @@ Tensor _get_complete_sum(const Tensor& lengths) {
   offsets[0].zero_();

   AT_DISPATCH_INDEX_TYPES(
-      lengths.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+      lengths.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
        auto* lengths_data_ptr = lengths.data_ptr<index_t>();
        auto* offsets_data_ptr = offsets.data_ptr<index_t>();
        at::cuda::cub::inclusive_sum(
@@ -108,22 +110,34 @@ __global__ void segment_reduce_forward_kernel(
     const index_t* lengths_data,
     const index_t* lengths_cumsum_data,
     const int64_t segment_count,
-    const int64_t stride_count,
+    const int64_t lengths_stride_axis,
     bool is_initial_set,
-    scalar_t initial_value) {
+    scalar_t initial_value,
+    const int64_t outer_offset,
+    const int64_t inner_offset,
+    const int64_t data_stride_axis,
+    const int64_t data_size_axis,
+    const int64_t output_stride_axis,
+    const int64_t output_size_axis,
+    const int64_t lengths_cumsum_stride_axis) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t row_id = idx / stride_count;
-  int64_t lane_id = idx % stride_count;
-  if (idx >= (segment_count * stride_count)) {
+  if (idx >= (outer_offset * segment_count * inner_offset)) {
     return;
   }
-  int64_t offset_start = lengths_cumsum_data[row_id];
-  int64_t offset_end = lengths_cumsum_data[row_id + 1];
+  int64_t row_id = idx / inner_offset;
+  int64_t lane_id = idx % inner_offset; // lane_id is the inner_idx
+  int64_t outer_idx = row_id / segment_count;
+  int64_t dim_idx = row_id % segment_count;
+
+  int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
+  index_t offset_start = lengths_cumsum_data[offset_idx];
+  index_t offset_end = lengths_cumsum_data[offset_idx + 1];

   // ===== step2: apply reduction
-  for (int64_t j = offset_start; j < offset_end; ++j) {
-    int64_t starting_index = (j * stride_count) + lane_id;
-    const auto data = values_data[starting_index];
+  for (index_t j = offset_start; j < offset_end; ++j) {
+    int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+        + j * data_stride_axis + lane_id;
+    const auto data = values_data[data_index];
     // TODO: There is no need to branch with every element
     if (reduction == SegmentReductionType::MAX) {
       initial_value =
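The rewritten kernel above flattens (outer batch, segment, inner lane) into one grid index and recovers the triple with div/mod. A minimal sketch of that decomposition, assuming the kernel's variable names (the Python helper itself is hypothetical):

```python
# Each CUDA thread owns one (outer batch, segment, inner lane) triple.
def decompose(idx, segment_count, inner_offset):
    row_id = idx // inner_offset
    lane_id = idx % inner_offset         # lane within the dims after the reduced axis
    outer_idx = row_id // segment_count  # slice within the dims before the reduced axis
    dim_idx = row_id % segment_count     # segment index along the reduced axis
    return outer_idx, dim_idx, lane_id

# e.g. with 2 segments and inner_offset == 3, flat index 7 maps to (1, 0, 1)
assert decompose(7, segment_count=2, inner_offset=3) == (1, 0, 1)
```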
@@ -142,19 +156,22 @@ __global__ void segment_reduce_forward_kernel(
   }

   // ===== step3: finalize reduction
-  CUDA_KERNEL_ASSERT(lengths_data[row_id] >= 0);
-  if (lengths_data[row_id] == 0 && !is_initial_set &&
+  int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
+  CUDA_KERNEL_ASSERT(lengths_data[lengths_idx] >= 0);
+  if (lengths_data[lengths_idx] == 0 && !is_initial_set &&
       reduction == SegmentReductionType::MEAN) {
     initial_value = static_cast<scalar_t>(NAN);
   } else if (
-      reduction == SegmentReductionType::MEAN && lengths_data[row_id] > 0 &&
+      reduction == SegmentReductionType::MEAN && lengths_data[lengths_idx] > 0 &&
       !at::_isnan(initial_value)) {
-    initial_value = initial_value / lengths_data[row_id];
+    initial_value = initial_value / lengths_data[lengths_idx];
   }
-  int64_t output_index = (row_id * stride_count) + lane_id;
+  int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+      + dim_idx * output_stride_axis + lane_id;
   output_data[output_index] = initial_value;
 }

+
 template <typename scalar_t, typename index_t>
 __global__ void segment_reduce_backward_kernel(
     SegmentReductionType reduction,
@@ -165,32 +182,46 @@ __global__ void segment_reduce_backward_kernel(
     const index_t* lengths_data,
     const index_t* lengths_cumsum_data,
     const int64_t segment_count,
-    const int64_t stride_count,
-    scalar_t initial_prod_value) {
+    const int64_t lengths_stride_axis,
+    scalar_t initial_prod_value,
+    const int64_t outer_offset,
+    const int64_t inner_offset,
+    const int64_t data_stride_axis,
+    const int64_t data_size_axis,
+    const int64_t output_stride_axis,
+    const int64_t output_size_axis,
+    const int64_t lengths_cumsum_stride_axis) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t row_id = idx / stride_count;
-  int64_t lane_id = idx % stride_count;
-
-  if (idx >= (segment_count * stride_count)) {
+  if (idx >= (outer_offset * segment_count * inner_offset)) {
     return;
   }
-  if (lengths_data[row_id] == 0) {
+  int64_t row_id = idx / inner_offset;
+  int64_t lane_id = idx % inner_offset; // lane_id is the inner_idx
+  int64_t outer_idx = row_id / segment_count;
+  int64_t dim_idx = row_id % segment_count;
+
+  int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
+  auto segment_length = lengths_data[lengths_idx];
+  if (segment_length == 0) {
     return;
   }

-  int64_t offset_start = lengths_cumsum_data[row_id];
-  int64_t offset_end = lengths_cumsum_data[row_id + 1];
+  int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
+  index_t offset_start = lengths_cumsum_data[offset_idx];
+  index_t offset_end = lengths_cumsum_data[offset_idx + 1];

-  int64_t output_index = (row_id * stride_count) + lane_id;
+  int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+      + dim_idx * output_stride_axis + lane_id;

   if (reduction == SegmentReductionType::MAX ||
       reduction == SegmentReductionType::MIN) {
     int64_t counter = 0;
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      if (at::_isnan(values_data[starting_index]) ||
-          values_data[starting_index] == output_data[output_index]) {
-        grad_input_data[starting_index] = grad_data[output_index];
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      if (at::_isnan(values_data[data_index]) ||
+          values_data[data_index] == output_data[output_index]) {
+        grad_input_data[data_index] = grad_data[output_index];
         counter++;
       }
     }
@@ -200,47 +231,51 @@ __global__ void segment_reduce_backward_kernel(
       return;
     }
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      if (grad_input_data[starting_index] > 0) {
-        grad_input_data[starting_index] =
-            grad_input_data[starting_index] / counter;
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      if (grad_input_data[data_index] > 0) {
+        grad_input_data[data_index] =
+            grad_input_data[data_index] / counter;
       }
     }
   } else if (reduction == SegmentReductionType::MEAN) {
-    auto grad_val = grad_data[output_index] / lengths_data[row_id];
+    auto grad_val = grad_data[output_index] / segment_length;
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      grad_input_data[starting_index] = grad_val;
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      grad_input_data[data_index] = grad_val;
     }
   } else if (reduction == SegmentReductionType::SUM) {
     const auto& grad_val = grad_data[output_index];
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      grad_input_data[starting_index] = grad_val;
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      grad_input_data[data_index] = grad_val;
     }
   } else if (reduction == SegmentReductionType::PROD) {
     const auto& grad_val = grad_data[output_index] * output_data[output_index];
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      if (at::_isnan(values_data[starting_index]) ||
-          values_data[starting_index] == 0) {
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      if (at::_isnan(values_data[data_index]) ||
+          values_data[data_index] == 0) {
         // explicitly compute exclusive prod
         scalar_t exclusive_prod = initial_prod_value;
-        int64_t idx;
+        int64_t prod_idx;
         for (int64_t k = offset_start; k < offset_end; ++k) {
           if (k != j) {
-            idx = (k * stride_count) + lane_id;
-            exclusive_prod *= values_data[idx];
+            prod_idx = outer_idx * data_stride_axis * data_size_axis
+                + k * data_stride_axis + lane_id;
+            exclusive_prod *= values_data[prod_idx];
           }
         }
-        grad_input_data[starting_index] = grad_data[output_index] * exclusive_prod;
+        grad_input_data[data_index] = grad_data[output_index] * exclusive_prod;
       } else {
-        grad_input_data[starting_index] = grad_val / values_data[starting_index];
+        grad_input_data[data_index] = grad_val / values_data[data_index];
      }
    }
  }
 }
-
 } // namespace

 Tensor _segment_reduce_cuda_backward_kernel(
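The PROD branch relies on d(∏ xᵢ)/dxⱼ = out / xⱼ when xⱼ is nonzero and non-NaN, and otherwise recomputes the exclusive product explicitly. A reference sketch of that rule (a hypothetical helper, not the kernel itself):

```python
import math

# Reference semantics for the PROD backward branch above (illustrative only).
def prod_grad(values, grad_out):
    out = math.prod(values)
    grads = []
    for j, x in enumerate(values):
        if x == 0 or math.isnan(x):
            # explicit exclusive product over the other elements of the segment
            exclusive = math.prod(v for k, v in enumerate(values) if k != j)
            grads.append(grad_out * exclusive)
        else:
            grads.append(grad_out * out / x)
    return grads

assert prod_grad([2.0, 0.0, 3.0], 1.0) == [0.0, 6.0, 0.0]
```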
@@ -251,28 +286,43 @@ Tensor _segment_reduce_cuda_backward_kernel(
     const Tensor& lengths_contig,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
-  int64_t segment_count = lengths_contig.numel();
-  auto output_shape = data_contig.sizes().vec();
-  output_shape[axis] = segment_count;
+  axis = lengths_contig.dim() - 1;
+  int64_t segment_count = lengths_contig.size(axis);
+  int64_t lengths_stride_axis = lengths_contig.stride(axis);
   auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());

-  int64_t stride_count = data_contig.numel() / data_contig.size(axis);
+  auto zeros_shape = lengths_contig.sizes().vec();
+  zeros_shape[axis] = 1;
+  auto offsets = at::cat({at::zeros(zeros_shape, lengths_contig.options()), lengths_contig}, axis);
+  offsets.cumsum_(axis);

-  auto offsets = _get_complete_sum(lengths_contig);
+  // outer_offset is the size of the outer dimensions of output (before axis)
+  // inner_offset is the size of the inner dimensions of output (after axis)
+  int64_t outer_offset = 1, inner_offset = 1;
+  for (int64_t d = 0; d < axis; d++) {
+    outer_offset *= output_contig.size(d);
+  }
+  for (int64_t d = axis + 1; d < output_contig.dim(); d++) {
+    inner_offset *= output_contig.size(d);
+  }

   constexpr int threads_per_block = 256;
-  int64_t num_blocks =
-      ((segment_count * stride_count) + threads_per_block - 1) /
-      threads_per_block;
+  int64_t num_blocks = (outer_offset * inner_offset * segment_count + threads_per_block - 1) / threads_per_block;

   num_blocks = std::max(num_blocks, (int64_t)1);

+  auto data_stride_axis = data_contig.stride(axis);
+  auto data_size_axis = data_contig.size(axis);
+  auto output_stride_axis = output_contig.stride(axis);
+  auto output_size_axis = output_contig.size(axis);
+  auto offsets_stride_axis = offsets.stride(axis);
+
   AT_DISPATCH_INDEX_TYPES(
-      lengths_contig.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+      lengths_contig.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
        const auto* lengths_data = lengths_contig.data_ptr<index_t>();
        auto* offsets_data = offsets.data_ptr<index_t>();

-        // TODO: Swtich to TensorIterator for better maintainablility and
+        // TODO: Switch to TensorIterator for better maintainablility and
        // readability
        AT_DISPATCH_FLOATING_TYPES_AND2(
            kBFloat16,
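Both launchers now build offsets by prepending a zero column to `lengths` and taking an in-place cumsum along the reduction axis, replacing the 1-D-only `_get_complete_sum`. The same construction mirrored in Python (example values are made up):

```python
import torch

lengths = torch.tensor([[2, 3], [4, 1]])
zeros_shape = list(lengths.shape)
zeros_shape[-1] = 1                      # zeros_shape[axis] = 1
offsets = torch.cat([torch.zeros(zeros_shape, dtype=lengths.dtype), lengths], dim=-1)
offsets.cumsum_(-1)
# offsets is now [[0, 2, 5],
#                 [0, 4, 5]]
```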
@@ -305,8 +355,16 @@ Tensor _segment_reduce_cuda_backward_kernel(
                 lengths_data,
                 offsets_data,
                 segment_count,
-                stride_count,
-                initial_prod_value);
+                lengths_stride_axis,
+                initial_prod_value,
+                outer_offset,
+                inner_offset,
+                data_stride_axis,
+                data_size_axis,
+                output_stride_axis,
+                output_size_axis,
+                offsets_stride_axis
+            );
             C10_CUDA_KERNEL_LAUNCH_CHECK();
           }));
     }));
@@ -319,24 +377,46 @@ Tensor _segment_reduce_cuda_kernel(
     const Tensor& lengths,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
-  int64_t segment_count = lengths.numel();
+  // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
+  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
+  TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous.");
+  axis = lengths.dim() - 1;
+  int64_t segment_count = lengths.size(axis);
+  int64_t lengths_stride_axis = lengths.stride(axis);
   auto output_shape = data.sizes().vec();
   output_shape[axis] = segment_count;
   auto output = at::empty(output_shape, data.options());

-  int64_t stride_count = data.numel() / data.size(axis);
-
-  auto offsets = _get_complete_sum(lengths);
+  // _get_complete_sum only supports 1D?
+  auto zeros_shape = lengths.sizes().vec();
+  zeros_shape[axis] = 1;
+  auto offsets = at::cat({at::zeros(zeros_shape, lengths.options()), lengths}, axis);
+  offsets.cumsum_(axis);
+
+  // outer_offset is the size of the outer dimensions of output (before axis)
+  // inner_offset is the size of the inner dimensions of output (after axis)
+  int64_t outer_offset = 1, inner_offset = 1;
+  for (int64_t d = 0; d < axis; d++) {
+    outer_offset *= output.size(d);
+  }
+  for (int64_t d = axis + 1; d < output.dim(); d++) {
+    inner_offset *= output.size(d);
+  }

   constexpr int threads_per_block = 256;
-  int64_t num_blocks =
-      ((segment_count * stride_count) + threads_per_block - 1) /
-      threads_per_block;
+  // segment_count * stride_count is just output.numel() ?
+  int64_t num_blocks = (output.numel() + threads_per_block - 1) / threads_per_block;

   num_blocks = std::max(num_blocks, (int64_t)1);

+  auto data_stride_axis = data.stride(axis);
+  auto data_size_axis = data.size(axis);
+  auto output_stride_axis = output.stride(axis);
+  auto output_size_axis = output.size(axis);
+  auto offsets_stride_axis = offsets.stride(axis);
+
   AT_DISPATCH_INDEX_TYPES(
-      lengths.type(), "_segment_reduce_cuda_kernel1", ([&] {
+      lengths.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
        auto* offsets_data_ptr = offsets.data_ptr<index_t>();
        auto* lengths_data_ptr = lengths.data_ptr<index_t>();
        AT_DISPATCH_FLOATING_TYPES_AND2(
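One thread per output element: the launch bound `outer_offset * segment_count * inner_offset` equals `output.numel()`, which is what the forward launcher uses. A quick sanity check with made-up shapes:

```python
import math

output_shape = [2, 4, 3, 5]  # hypothetical output, segments along axis=1
axis = 1
outer_offset = math.prod(output_shape[:axis])      # dims before axis -> 2
inner_offset = math.prod(output_shape[axis + 1:])  # dims after axis  -> 15
assert outer_offset * output_shape[axis] * inner_offset == math.prod(output_shape)

threads_per_block = 256
num_blocks = max((math.prod(output_shape) + threads_per_block - 1) // threads_per_block, 1)
```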
@@ -376,9 +456,17 @@ Tensor _segment_reduce_cuda_kernel(
                 lengths_data_ptr,
                 offsets_data_ptr,
                 segment_count,
-                stride_count,
+                lengths_stride_axis,
                 initial.has_value(),
-                initial_value);
+                initial_value,
+                outer_offset,
+                inner_offset,
+                data_stride_axis,
+                data_size_axis,
+                output_stride_axis,
+                output_size_axis,
+                offsets_stride_axis
+            );
             C10_CUDA_KERNEL_LAUNCH_CHECK();
           } else {
             if (reduction == SegmentReductionType::MAX) {

test/test_ops.py (+1 −1)
@@ -1547,7 +1547,7 @@ def test_refs_are_in_python_ref_db(self, op):
         "to_sparse",  # Could not run 'aten::to_sparse' with arguments from the 'Meta' backend
         "tensor_split",  # The tensor has a non-zero number of elements, but its data is not allocated yet
         "repeat_interleave",  # cannot repeat_interleave a meta tensor without output_size
-        "segment_reduce",  # Could not run 'aten::segment_reduce' with arguments from the 'Meta' backend.
+        "segment_reduce.lengths",  # Could not run 'aten::segment_reduce' with arguments from the 'Meta' backend.
         "sparse.sampled.addmm",  # sparsity not supported
         # Can not infer total number of classes from meta. no way at present to throw DynamicOutputShapeException
         "nn.functional.one_hot",
