
Commit 1c99485

Add vectorization in elementwise_util (not working yet)
This works with op_mul, which is vectorization-friendly, but it does not work when rolled out to pattern.h, because those ops will not work with Vectorized yet. See the TODO in elementwise_util.h.

ghstack-source-id: 30d2311bed080c3a5390ab00ca20a1e33563f077
ghstack-comment-id: 2738665976
Pull Request resolved: #9432
1 parent 4fbbf97 commit 1c99485
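
A note on what "vectorization-friendly" means here: op_mul's new lambda deduces its argument types, so the same callable can be instantiated with the scalar compute type or with at::vec::Vectorized of that type. A minimal sketch of that property (hypothetical standalone example, assuming ET_USE_PYTORCH_HEADERS is defined so the ATen vec header is available; not part of this diff):

#include <ATen/cpu/vec/vec.h>

void mul_lambda_scalar_and_vector() {
  // Same shape of lambda as the one op_mul.cpp switches to below:
  // argument types are deduced, not pinned to CTYPE_COMPUTE.
  auto mul = [](const auto val_a, const auto val_b) { return val_a * val_b; };

  // Scalar instantiation: used by the existing element-by-element path.
  float s = mul(2.0f, 3.0f);
  (void)s;

  // Vector instantiation: operator* is defined for at::vec::Vectorized<float>,
  // so the same lambda also compiles on the new vectorized path.
  using Vec = at::vec::Vectorized<float>;
  Vec v = mul(Vec(2.0f), Vec(3.0f));
  float buf[Vec::size()];
  v.store(buf);  // writes Vec::size() floats
}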

File tree: 6 files changed, +109 −14 lines


.lintrunner.toml

+2
@@ -264,6 +264,8 @@ exclude_patterns = [
     'examples/**',
     'exir/verification/bindings.cpp',
     'extension/**',
+    # Uses properly-gated (ET_USE_PYTORCH_HEADERS) ATen include.
+    'kernels/portable/cpu/util/elementwise_util.h',
     'kernels/optimized/**',
     'runtime/core/exec_aten/**',
     # Want to be able to keep c10 in sync with PyTorch core.

kernels/portable/cpu/op_mul.cpp

+1 −3

@@ -56,9 +56,7 @@ Tensor& mul_out(
       CTYPE_COMPUTE,
       op_name,
       utils::SupportedTensorDtypes::REALHBBF16>(
-      [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-        return val_a * val_b;
-      },
+      [](const auto val_a, const auto val_b) { return val_a * val_b; },
       ctx,
       a,
       utils::SupportedTensorDtypes::REALHBBF16,

kernels/portable/cpu/pattern/pattern.h

+6 −9

@@ -80,13 +80,12 @@ Tensor& unary_ufunc_realh(
       ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out);

   ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
-    utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
+    utils::apply_unitensor_elementwise_fn<CTYPE, op_name, utils::SupportedTensorDtypes::SAME_AS_COMMON>(
         fn,
         ctx,
         in,
         utils::SupportedTensorDtypes::REALH,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+        out);
   });
   return out;
 }

@@ -107,13 +106,12 @@ Tensor& unary_ufunc_realhb_to_bool(
     return out;
   }
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
-    utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name>(
+    utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name, utils::SupportedTensorDtypes::BOOL>(
         [fn](const CTYPE_IN val_in) { return fn(val_in); },
         ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::BOOL);
+        out);
   });

   return out;

@@ -138,13 +136,12 @@ Tensor& unary_ufunc_realhbbf16_to_floathbf16(
   }

   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
-    utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name>(
+    utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name, utils::SupportedTensorDtypes::FLOATHBF16>(
         [fn](const CTYPE_IN val_in) { return fn(val_in); },
         ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::FLOATHBF16);
+        out);
   });

   return out;
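
Why these pattern.h call sites still take only the scalar path (the commit message's "those ops will not work with Vectorized yet"): the wrapper lambdas above pin their argument to the concrete CTYPE_IN, and the underlying fn is a plain scalar routine, so neither can be instantiated with a Vectorized argument. A minimal sketch of the mismatch, using std::ceil as a hypothetical stand-in for fn (not part of this diff):

#include <ATen/cpu/vec/vec.h>
#include <cmath>

void pattern_h_lambda_is_scalar_only() {
  // fn in the spirit of what the pattern.h helpers receive: a scalar function.
  auto fn = [](double x) { return std::ceil(x); };
  // The pattern.h wrapper fixes the argument type to a concrete scalar.
  auto wrapped = [fn](const float val_in) { return fn(val_in); };

  float ok = wrapped(1.5f);  // scalar call compiles
  (void)ok;

  // There is no conversion from at::vec::Vectorized<float> to float, so the
  // vectorized instantiation below would not compile; this is what the TODO
  // in elementwise_util.h is about.
  // auto bad = wrapped(at::vec::Vectorized<float>(1.5f));
}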

kernels/portable/cpu/util/elementwise_util.h

+95 −1

@@ -15,6 +15,10 @@
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 #include <executorch/runtime/kernel/thread_parallel_interface.h>

+#ifdef ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif // ET_USE_PYTORCH_HEADERS
+
 #include <array>
 #include <utility>

@@ -58,6 +62,38 @@ template <typename CTYPE_COMMON, typename Op, typename... Args>
 using op_call_result =
     std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;

+#ifdef ET_USE_PYTORCH_HEADERS
+template <typename T>
+struct is_vectorized : public std::false_type {};
+
+template <typename T>
+struct is_vectorized<at::vec::Vectorized<T>> : public std::true_type {};
+
+// TODO: can_use_vectorized and can_use_vectorized_impl are a failed
+// attempt to use SFINAE to detect whether our generic lambda argument
+// with deduced return type would compile if it was passed
+// Vectorized<CTYPE_COMMON> instead of CTYPE_COMMON. SFINAE does not
+// work that way (see
+// e.g. https://stackoverflow.com/questions/53344484/hard-error-when-using-stdinvoke-result-t-with-a-generic-lambda,
+// https://stackoverflow.com/questions/31368601/how-to-detect-if-a-generic-lambda-is-uncompilable-in-c-14);
+// if we really want to do it then we need to at least require that
+// our lambdas actively participate in being SFINAE-friendly, as in
+// https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable.
+template <typename CTYPE_COMMON, typename Op, typename Enable = void, typename... Args>
+struct can_use_vectorized_impl : std::false_type {};
+template <typename CTYPE_COMMON, typename Op, typename... Args>
+struct can_use_vectorized_impl<CTYPE_COMMON, Op, typename std::void_t<decltype(std::declval<std::invoke_result_t<
+    Op,
+    ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>>().store(std::declval<CTYPE_COMMON*>()))>, Args...> : public std::true_type {};//std::bool_constant<is_vectorized<std::invoke_result_t<Op,ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>>::value> {};
+
+// Can I call a function of type Op with sizeof...(Args) arguments of type
+// at::vec::Vectorized<CTYPE_COMMON>?
+// This is not possible in C++17 as the code is currently set up; see TODO above.
+template <typename CTYPE_COMMON, typename Op, typename... Args>
+struct can_use_vectorized : public can_use_vectorized_impl<CTYPE_COMMON, Op, void, Args...> {};
+
+#endif // ET_USE_PYTORCH_HEADERS
+
 template <
     typename CTYPE_COMMON,
     typename CTYPE_OUT,
@@ -68,14 +104,72 @@ inline void dtype_specialized_elementwise_fn_impl(
     KernelRuntimeContext& ctx,
     const Tensor& out,
     Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
   constexpr auto kNumInputs = sizeof...(inputs);
-  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMMON)) && ...));
+  // All inputs must be of type CTYPE_COMMON.
+  ET_DCHECK(
+      ((inputs.first->scalar_type() ==
+        CppTypeToScalarType<CTYPE_COMMON>::value) &&
+       ...));

   std::array<const CTYPE_COMMON*, kNumInputs> inputs_data_ptrs = {
       inputs.first->template const_data_ptr<CTYPE_COMMON>()...};

   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

+#ifdef ET_USE_PYTORCH_HEADERS
+  if constexpr (can_use_vectorized<CTYPE_COMMON, Op, Args...>::value) {
+    const bool any_is_broadcasted =
+        !(torch::executor::internal::sizes_match_ignoring_leading_1s(
+              inputs.first->sizes(), out.sizes()) &&
+          ...);
+    if (!any_is_broadcasted) {
+      using Vec = at::vec::Vectorized<CTYPE_COMMON>;
+      ::executorch::extension::parallel_for(
+          0,
+          out.numel(),
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            const auto vectorized_begin =
+                begin + (Vec::size() - begin % Vec::size()) % Vec::size();
+            const auto vectorized_end = end - (end % Vec::size());
+            // Scalar prologue.
+            for (const auto idx : c10::irange(begin, vectorized_begin)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+
+            // Main vectorized loop.
+            for (auto idx = vectorized_begin; idx < vectorized_end;
+                 idx += Vec::size()) {
+              std::array<Vec, kNumInputs> loaded_vec_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_vec_inputs[input_idx] =
+                    Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
+              }
+              auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
+              result_vec.store(&data_out[idx]);
+            }
+
+            // Scalar epilogue.
+            for (const auto idx : c10::irange(vectorized_end, end)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+          });
+      return;
+    }
+  }
+#endif
+
   ::executorch::extension::parallel_for(
       0,
       out.numel(),
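
The vectorized path above splits each parallel_for chunk [begin, end) into a scalar prologue up to the first index aligned to Vec::size(), a main loop over whole vectors, and a scalar epilogue for the tail. A standalone sketch of that index arithmetic, with kVecSize standing in for Vec::size() (hypothetical helper, not part of this diff; the example assumes the chunk covers at least one full vector):

#include <cstddef>
#include <cstdio>

void show_partition(std::size_t begin, std::size_t end, std::size_t kVecSize) {
  // Same formulas as in dtype_specialized_elementwise_fn_impl.
  const auto vectorized_begin =
      begin + (kVecSize - begin % kVecSize) % kVecSize;
  const auto vectorized_end = end - (end % kVecSize);
  std::printf(
      "prologue [%zu, %zu)  vector loop [%zu, %zu)  epilogue [%zu, %zu)\n",
      begin,
      vectorized_begin,
      vectorized_begin,
      vectorized_end,
      vectorized_end,
      end);
}

// For example, show_partition(5, 29, 8) prints:
//   prologue [5, 8)  vector loop [8, 24)  epilogue [24, 29)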

kernels/portable/cpu/util/targets.bzl

+1

@@ -110,6 +110,7 @@ def define_common_targets():
             ":broadcast_indexes_range",
             ":broadcast_util",
             ":dtype_util",
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
             "//executorch/runtime/kernel:kernel_runtime_context",
             "//executorch/runtime/kernel:thread_parallel_interface",
         ],

runtime/core/portable_type/c10/c10/targets.bzl

+4 −1

@@ -49,7 +49,10 @@ def define_common_targets():
     runtime.cxx_library(
         name = "aten_headers_for_executorch",
         srcs = [],
-        visibility = ["//executorch/kernels/optimized/..."],
+        visibility = [
+            "//executorch/kernels/optimized/...",
+            "//executorch/kernels/portable/cpu/util/...",
+        ],
         exported_deps = select({
             "DEFAULT": [],
             "ovr_config//cpu:arm64": [
