
Commit 7f8cfb4

Add vectorization in elementwise_util (not working yet)

This works with op_mul, which is Vectorized-friendly, but does not yet work when rolled out to pattern.h, because those ops will not work with Vectorized. See the TODO in elementwise_util.h.

ghstack-source-id: d546d5d595929e84814aa38833c8a07bf3cf6ec5
ghstack-comment-id: 2738665976
Pull Request resolved: #9432
1 parent 6eaf791 commit 7f8cfb4
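For context, a minimal standalone sketch (not part of this commit) of what "Vectorized-friendly" means here, assuming ATen headers are available as gated by ET_USE_PYTORCH_HEADERS in the diff below: at::vec::Vectorized overloads the arithmetic operators, so a generic lambda like op_mul's compiles both for scalars and for whole SIMD lanes.

    #include <ATen/cpu/vec/vec.h>

    int main() {
      using Vec = at::vec::Vectorized<float>;

      // The same generic lambda works at two granularities:
      auto mul = [](const auto val_a, const auto val_b) { return val_a * val_b; };

      float scalar = mul(2.0f, 3.0f); // one scalar multiply
      Vec lanes = mul(Vec(2.0f), Vec(3.0f)); // one multiply across all lanes

      (void)scalar;
      (void)lanes;
      return 0;
    }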

6 files changed, +104 -14 lines changed

.lintrunner.toml (+2)

@@ -264,6 +264,8 @@ exclude_patterns = [
     'examples/**',
     'exir/verification/bindings.cpp',
     'extension/**',
+    # Uses properly-gated (ET_USE_PYTORCH_HEADERS) ATen include.
+    'kernels/portable/cpu/util/elementwise_util.h',
     'kernels/optimized/**',
     'runtime/core/exec_aten/**',
     # Want to be able to keep c10 in sync with PyTorch core.

kernels/portable/cpu/op_mul.cpp (+1 -3)

@@ -56,9 +56,7 @@ Tensor& mul_out(
       CTYPE_COMPUTE,
       op_name,
       utils::SupportedTensorDtypes::REALHBBF16>(
-      [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-        return val_a * val_b;
-      },
+      [](const auto val_a, const auto val_b) { return val_a * val_b; },
       ctx,
       a,
       utils::SupportedTensorDtypes::REALHBBF16,
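The switch from a CTYPE_COMPUTE-typed lambda to a generic one is what lets the new fast path trigger: the dispatch (see can_use_vectorized in elementwise_util.h below) asks whether the lambda is invocable with Vectorized arguments. A self-contained illustration; FakeVec is a hypothetical stand-in so the checks compile without ATen:

    #include <type_traits>

    // Hypothetical stand-in for at::vec::Vectorized<float>, with just enough
    // surface (operator*) for the invocability checks below.
    struct FakeVec {
      float lanes[8];
      FakeVec operator*(const FakeVec&) const { return *this; }
    };

    int main() {
      // The old lambda is pinned to the compute type, so no vector path:
      auto typed = [](const float a, const float b) { return a * b; };
      static_assert(!std::is_invocable_v<decltype(typed), FakeVec, FakeVec>);

      // The new generic lambda also instantiates lane-wise:
      auto generic = [](const auto a, const auto b) { return a * b; };
      static_assert(std::is_invocable_v<decltype(generic), FakeVec, FakeVec>);
      return 0;
    }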

kernels/portable/cpu/pattern/pattern.h (+6 -9)

@@ -80,13 +80,12 @@ Tensor& unary_ufunc_realh(
       ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out);

   ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
-    utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
+    utils::apply_unitensor_elementwise_fn<CTYPE, op_name, utils::SupportedTensorDtypes::SAME_AS_COMMON>(
         fn,
         ctx,
         in,
         utils::SupportedTensorDtypes::REALH,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+        out);
   });
   return out;
 }

@@ -107,13 +106,12 @@ Tensor& unary_ufunc_realhb_to_bool(
     return out;
   }
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
-    utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name>(
+    utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name, utils::SupportedTensorDtypes::BOOL>(
         [fn](const CTYPE_IN val_in) { return fn(val_in); },
         ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::BOOL);
+        out);
   });

   return out;

@@ -138,13 +136,12 @@ Tensor& unary_ufunc_realhbbf16_to_floathbf16(
   }

   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
-    utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name>(
+    utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name, utils::SupportedTensorDtypes::FLOATHBF16>(
         [fn](const CTYPE_IN val_in) { return fn(val_in); },
         ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::FLOATHBF16);
+        out);
   });

   return out;

kernels/portable/cpu/util/elementwise_util.h (+90 -1)

@@ -15,6 +15,10 @@
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 #include <executorch/runtime/kernel/thread_parallel_interface.h>

+#ifdef ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif // ET_USE_PYTORCH_HEADERS
+
 #include <array>
 #include <utility>

@@ -58,6 +62,19 @@ template <typename CTYPE_COMMON, typename Op, typename... Args>
 using op_call_result =
     std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;

+#ifdef ET_USE_PYTORCH_HEADERS
+// Can I call a function of type Op with sizeof...(Args) arguments of type
+// at::vec::Vectorized<CTYPE_COMMON>?
+//
+// See [NOTE: Generic lambdas] below for requirements on Op.
+template <typename CTYPE_COMMON, typename Op, typename... Args>
+constexpr bool can_use_vectorized() {
+  return std::is_invocable_v<
+      Op,
+      ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>;
+}
+#endif // ET_USE_PYTORCH_HEADERS
+
 template <
     typename CTYPE_COMMON,
     typename CTYPE_OUT,
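A standalone illustration of the same detection idea, with the internal ignore_first_yield_second plumbing replaced by a std::conditional_t expansion that maps every argument position to Vectorized (the sketch's names are illustrative; it assumes ATen headers are on the include path):

    #include <type_traits>
    #include <ATen/cpu/vec/vec.h>

    // Simplified trait: can Op be called with sizeof...(Args) Vectorized values?
    template <typename CTYPE, typename Op, typename... Args>
    constexpr bool can_use_vectorized_sketch() {
      return std::is_invocable_v<
          Op,
          // Each Args position is mapped to Vectorized<CTYPE>:
          std::conditional_t<true, at::vec::Vectorized<CTYPE>, Args>...>;
    }

    int main() {
      auto mul = [](const auto a, const auto b) { return a * b; };
      static_assert(can_use_vectorized_sketch<float, decltype(mul), int, int>());
      return 0;
    }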
@@ -68,14 +85,72 @@ inline void dtype_specialized_elementwise_fn_impl(
     KernelRuntimeContext& ctx,
     const Tensor& out,
     Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
   constexpr auto kNumInputs = sizeof...(inputs);
-  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMMON)) && ...));
+  // All inputs must be of type CTYPE_COMMON.
+  ET_DCHECK(
+      ((inputs.first->scalar_type() ==
+        CppTypeToScalarType<CTYPE_COMMON>::value) &&
+       ...));

   std::array<const CTYPE_COMMON*, kNumInputs> inputs_data_ptrs = {
       inputs.first->template const_data_ptr<CTYPE_COMMON>()...};

   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

+#ifdef ET_USE_PYTORCH_HEADERS
+  if constexpr (can_use_vectorized<CTYPE_COMMON, Op, Args...>()) {
+    const bool any_is_broadcasted =
+        !(torch::executor::internal::sizes_match_ignoring_leading_1s(
+              inputs.first->sizes(), out.sizes()) &&
+          ...);
+    if (!any_is_broadcasted) {
+      using Vec = at::vec::Vectorized<CTYPE_COMMON>;
+      ::executorch::extension::parallel_for(
+          0,
+          out.numel(),
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            const auto vectorized_begin =
+                begin + (Vec::size() - begin % Vec::size()) % Vec::size();
+            const auto vectorized_end = end - (end % Vec::size());
+            // Scalar prologue.
+            for (const auto idx : c10::irange(begin, vectorized_begin)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+
+            // Main vectorized loop.
+            for (auto idx = vectorized_begin; idx < vectorized_end;
+                 idx += Vec::size()) {
+              std::array<Vec, kNumInputs> loaded_vec_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_vec_inputs[input_idx] =
+                    Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
+              }
+              auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
+              result_vec.store(&data_out[idx]);
+            }
+
+            // Scalar epilogue.
+            for (const auto idx : c10::irange(vectorized_end, end)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+          });
+      return;
+    }
+  }
+#endif // ET_USE_PYTORCH_HEADERS
+
   ::executorch::extension::parallel_for(
       0,
       out.numel(),
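The prologue/epilogue arithmetic rounds begin up, and end down, to the nearest multiple of Vec::size(), so only the aligned middle runs Vec-wide. A standalone sketch with illustrative numbers (kVecSize standing in for Vec::size()):

    #include <cstdio>

    int main() {
      constexpr int kVecSize = 8; // stand-in for Vec::size()
      const int begin = 5, end = 30;
      const int vectorized_begin =
          begin + (kVecSize - begin % kVecSize) % kVecSize; // rounds 5 up to 8
      const int vectorized_end = end - (end % kVecSize); // rounds 30 down to 24
      std::printf("scalar prologue: [%d, %d)\n", begin, vectorized_begin);
      std::printf("vector body:     [%d, %d), stride %d\n",
                  vectorized_begin, vectorized_end, kVecSize);
      std::printf("scalar epilogue: [%d, %d)\n", vectorized_end, end);
      return 0;
    }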
@@ -255,6 +330,17 @@ inline void apply_unitensor_elementwise_fn(
       compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
 }

+/**
+ * Useful for unary elementwise operators. For each element of the
+ * input, call Op and write to the corresponding element of the
+ * output. Tensor broadcasting is applied wherever it is required.
+ *
+ * [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
+ * parameters; normal lambdas are fine), it must fulfill one of the
+ * following conditions. Either:
+ * 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMMON>, or
+ * 2) It must be actively SFINAE-friendly, as per the C++17 examples in https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable .
+ */
 template <
     typename CTYPE_COMMON,
     const char* op_name,
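Condition (2) refers to generic lambdas whose requirements live in the immediate context, e.g. via a trailing return type, so that invocability checks report false instead of failing to compile. A minimal illustration in the spirit of the linked discussion (not code from this commit):

    #include <string>
    #include <type_traits>

    int main() {
      // SFINAE-friendly: decltype(a * b) repeats the body's requirement in
      // the immediate context, so detection degrades gracefully.
      auto mul = [](const auto a, const auto b) -> decltype(a * b) {
        return a * b;
      };
      static_assert(std::is_invocable_v<decltype(mul), float, float>);
      static_assert(!std::is_invocable_v<decltype(mul), std::string, std::string>);
      return 0;
    }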
@@ -296,6 +382,7 @@ inline void apply_bitensor_elementwise_fn(
  * Useful for bi-tensor elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for compute_fun.
  */
 template <
     typename CTYPE_COMMON,
@@ -362,6 +449,8 @@ inline void apply_tritensor_elementwise_fn(
  *
  * static constexpr const char op_name[] = "my_op";
  * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
+ *
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for compute_fun.
  */
 template <
     typename CTYPE_COMMON,

kernels/portable/cpu/util/targets.bzl (+1)

@@ -110,6 +110,7 @@ def define_common_targets():
         ":broadcast_indexes_range",
         ":broadcast_util",
         ":dtype_util",
+        "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
         "//executorch/runtime/kernel:kernel_runtime_context",
         "//executorch/runtime/kernel:thread_parallel_interface",
     ],

runtime/core/portable_type/c10/c10/targets.bzl (+4 -1)

@@ -49,7 +49,10 @@ def define_common_targets():
     runtime.cxx_library(
         name = "aten_headers_for_executorch",
         srcs = [],
-        visibility = ["//executorch/kernels/optimized/..."],
+        visibility = [
+            "//executorch/kernels/optimized/...",
+            "//executorch/kernels/portable/cpu/util/...",
+        ],
         exported_deps = select({
             "DEFAULT": [],
             "ovr_config//cpu:arm64": [
