Skip to content

Commit 6eaf791

Browse files
committedMar 20, 2025
BroadcastIndexesRange: leading 1s don't require true broadcasting
Moved the mechanism we use to detect broadcasting from optimized/util/binary_ops.h ghstack-source-id: 3af3983e3474e86e20267f755def81d74893b9b6 ghstack-comment-id: 2738665656 Pull Request resolved: #9431
1 parent 47bdd56 commit 6eaf791

File tree

3 files changed

+31
-26
lines changed

3 files changed

+31
-26
lines changed
 

‎kernels/optimized/cpu/binary_ops.h

+1-24
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,11 @@
1010

1111
#include <executorch/kernels/optimized/vec/functional.h>
1212
#include <executorch/kernels/portable/cpu/scalar_utils.h>
13+
#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
1314
#include <executorch/runtime/kernel/kernel_includes.h>
1415

1516
namespace torch {
1617
namespace executor {
17-
namespace internal {
18-
// NOTE: we bake ArrayRef iterators being pointers into the return
19-
// type here because we assume that iterators are portable across
20-
// ArrayRef copies.
21-
inline const Tensor::SizesType* arrayref_begin_ignoring_leading_1s(
22-
ArrayRef<Tensor::SizesType> arr) {
23-
return std::find_if(
24-
arr.begin(), arr.end(), [](Tensor::SizesType x) { return x != 1; });
25-
}
26-
27-
inline bool sizes_match_ignoring_leading_1s(
28-
ArrayRef<Tensor::SizesType> lhs,
29-
ArrayRef<Tensor::SizesType> rhs) {
30-
auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs);
31-
auto lhs_end = lhs.end();
32-
33-
auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs);
34-
auto rhs_end = rhs.end();
35-
36-
return ((lhs_end - lhs_begin) == (rhs_end - rhs_begin)) &&
37-
std::equal(lhs_begin, lhs_end, rhs_begin);
38-
}
39-
} // namespace internal
40-
4118
enum class ElementwiseOptimizedPath {
4219
kNone,
4320
kTreatAs1d,

‎kernels/optimized/cpu/targets.bzl

+4-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,10 @@ def define_common_targets():
130130
srcs = [],
131131
exported_headers = ["op_add_sub_impl.h"],
132132
visibility = ["//executorch/kernels/optimized/cpu/..."],
133-
exported_deps = ["//executorch/runtime/core:core"],
133+
exported_deps = [
134+
"//executorch/runtime/core:core",
135+
"//executorch/kernels/portable/cpu/util:broadcast_indexes_range",
136+
],
134137
)
135138

136139
runtime.cxx_library(

‎kernels/portable/cpu/util/broadcast_indexes_range.h

+26-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,28 @@
2121
namespace torch::executor {
2222

2323
namespace internal {
24+
// NOTE: we bake ArrayRef iterators being pointers into the return
25+
// type here because we assume that iterators are portable across
26+
// ArrayRef copies.
27+
inline const Tensor::SizesType* arrayref_begin_ignoring_leading_1s(
28+
ArrayRef<Tensor::SizesType> arr) {
29+
return std::find_if(
30+
arr.begin(), arr.end(), [](Tensor::SizesType x) { return x != 1; });
31+
}
32+
33+
inline bool sizes_match_ignoring_leading_1s(
34+
ArrayRef<Tensor::SizesType> lhs,
35+
ArrayRef<Tensor::SizesType> rhs) {
36+
auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs);
37+
auto lhs_end = lhs.end();
38+
39+
auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs);
40+
auto rhs_end = rhs.end();
41+
42+
return ((lhs_end - lhs_begin) == (rhs_end - rhs_begin)) &&
43+
std::equal(lhs_begin, lhs_end, rhs_begin);
44+
}
45+
2446
template <std::size_t kNumInputs>
2547
class BroadcastIndexesIterator {
2648
public:
@@ -35,7 +57,10 @@ class BroadcastIndexesIterator {
3557
template <typename... Args>
3658
explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
3759
: output_dim_or_zero_if_no_broadcasting_(
38-
((args.sizes() == output.sizes()) && ...) ? 0 : output.dim()),
60+
(sizes_match_ignoring_leading_1s(args.sizes(), output.sizes()) &&
61+
...)
62+
? 0
63+
: output.dim()),
3964
output_shape_(output.sizes()) {
4065
static_assert(
4166
sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),

0 commit comments

Comments
 (0)