#include <executorch/runtime/kernel/kernel_runtime_context.h>
#include <executorch/runtime/kernel/thread_parallel_interface.h>

+#ifdef ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif // ET_USE_PYTORCH_HEADERS
+
#include <array>
#include <utility>

@@ -58,6 +62,19 @@ template <typename CTYPE_COMMON, typename Op, typename... Args>
using op_call_result =
    std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;

+#ifdef ET_USE_PYTORCH_HEADERS
+// Can I call a function of type Op with sizeof...(Args) arguments of type
+// at::vec::Vectorized<CTYPE_COMMON>?
+//
+// See [NOTE: Generic lambdas] below for requirements on Op.
+template <typename CTYPE_COMMON, typename Op, typename... Args>
+constexpr bool can_use_vectorized() {
+  return std::is_invocable_v<
+      Op,
+      ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>;
+}
+#endif // ET_USE_PYTORCH_HEADERS
+
template <
    typename CTYPE_COMMON,
    typename CTYPE_OUT,
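
Editor's note: the trait above works by expanding the argument pack so that every input maps to the vector type and then asking `std::is_invocable_v` whether Op accepts that many vector arguments. Below is a minimal self-contained sketch of the same shape; `ignore_first_yield_second_sketch`, `StandInVec`, and `can_use_vectorized_sketch` are illustrative stand-ins, not ExecuTorch or ATen identifiers.

```cpp
#include <type_traits>

// Maps any first type to the second; mirrors the role that
// ignore_first_yield_second plays in this header.
template <typename Ignored, typename T>
using ignore_first_yield_second_sketch = T;

// Stand-in for at::vec::Vectorized<float>: supports operator* only.
struct StandInVec {
  float lanes[8];
  StandInVec operator*(const StandInVec&) const { return *this; }
};

template <typename Op, typename... Args>
constexpr bool can_use_vectorized_sketch() {
  // One StandInVec per input argument, regardless of what Args actually are.
  return std::is_invocable_v<
      Op,
      ignore_first_yield_second_sketch<Args, StandInVec>...>;
}

int main() {
  auto mul = [](auto a, auto b) { return a * b; };
  (void)mul; // only used in an unevaluated context below
  // Two inputs of arbitrary type collapse to two StandInVec arguments.
  static_assert(can_use_vectorized_sketch<decltype(mul), int, int>());
}
```
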
@@ -68,14 +85,72 @@ inline void dtype_specialized_elementwise_fn_impl(
    KernelRuntimeContext& ctx,
    const Tensor& out,
    Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
  constexpr auto kNumInputs = sizeof...(inputs);
-  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMMON)) && ...));
+  // All inputs must be of type CTYPE_COMMON.
+  ET_DCHECK(
+      ((inputs.first->scalar_type() ==
+        CppTypeToScalarType<CTYPE_COMMON>::value) &&
+       ...));

  std::array<const CTYPE_COMMON*, kNumInputs> inputs_data_ptrs = {
      inputs.first->template const_data_ptr<CTYPE_COMMON>()...};

  CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

+#ifdef ET_USE_PYTORCH_HEADERS
+  if constexpr (can_use_vectorized<CTYPE_COMMON, Op, Args...>()) {
+    const bool any_is_broadcasted =
+        !(torch::executor::internal::sizes_match_ignoring_leading_1s(
+              inputs.first->sizes(), out.sizes()) &&
+          ...);
+    if (!any_is_broadcasted) {
+      using Vec = at::vec::Vectorized<CTYPE_COMMON>;
+      ::executorch::extension::parallel_for(
+          0,
+          out.numel(),
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            // Round begin up and end down to multiples of Vec::size(); the
+            // ragged edges are handled by the scalar prologue and epilogue.
+            const auto vectorized_begin =
+                begin + (Vec::size() - begin % Vec::size()) % Vec::size();
+            const auto vectorized_end = end - (end % Vec::size());
+            // Scalar prologue.
+            for (const auto idx : c10::irange(begin, vectorized_begin)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+
+            // Main vectorized loop.
+            for (auto idx = vectorized_begin; idx < vectorized_end;
+                 idx += Vec::size()) {
+              std::array<Vec, kNumInputs> loaded_vec_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_vec_inputs[input_idx] =
+                    Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
+              }
+              auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
+              result_vec.store(&data_out[idx]);
+            }
+
+            // Scalar epilogue.
+            for (const auto idx : c10::irange(vectorized_end, end)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+          });
+      return;
+    }
+  }
+#endif // ET_USE_PYTORCH_HEADERS
+
  ::executorch::extension::parallel_for(
      0,
      out.numel(),
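
Editor's note: for readers unfamiliar with the prologue / main loop / epilogue structure in the lambda above, here is a standalone sketch of the same index arithmetic with a plain fixed-width block standing in for at::vec::Vectorized. `kVecSize`, `add_range`, and the scalar inner loop are illustrative only; the real code uses Vec::loadu, std::apply on the loaded vectors, and Vec::store, and this sketch additionally clamps vectorized_begin for ranges shorter than one block.

```cpp
#include <algorithm>
#include <cstddef>

constexpr std::size_t kVecSize = 8; // stand-in for Vec::size()

// Writes a[i] + b[i] into out[i] for i in [begin, end).
void add_range(
    const float* a,
    const float* b,
    float* out,
    std::size_t begin,
    std::size_t end) {
  // Round begin up and end down to multiples of kVecSize; the ragged edges
  // are processed one element at a time.
  const std::size_t vectorized_begin =
      std::min(begin + (kVecSize - begin % kVecSize) % kVecSize, end);
  const std::size_t vectorized_end =
      vectorized_begin + (end - vectorized_begin) / kVecSize * kVecSize;

  // Scalar prologue up to the first block boundary.
  for (std::size_t i = begin; i < vectorized_begin; ++i) {
    out[i] = a[i] + b[i];
  }
  // Main loop: one full block per iteration (Vec::loadu / op / Vec::store in
  // the real code above).
  for (std::size_t i = vectorized_begin; i < vectorized_end; i += kVecSize) {
    for (std::size_t lane = 0; lane < kVecSize; ++lane) {
      out[i + lane] = a[i + lane] + b[i + lane];
    }
  }
  // Scalar epilogue for the leftover tail.
  for (std::size_t i = vectorized_end; i < end; ++i) {
    out[i] = a[i] + b[i];
  }
}
```

The scalar loops keep the kernel correct when begin and end are not multiples of the vector width; the main loop does the bulk of the work one full vector at a time.
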
@@ -255,6 +330,17 @@ inline void apply_unitensor_elementwise_fn(
      compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
}

+/**
+ * Useful for unary elementwise operators. For each element of the
+ * input, call Op and write to the corresponding element of the
+ * output. Tensor broadcasting is applied wherever it is required.
+ *
+ * [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
+ * parameters; normal lambdas are fine), it must fulfill one of the
+ * following conditions. Either:
+ * 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMMON>, or
+ * 2) It must be actively SFINAE-friendly, as per the C++17 examples in https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable .
+ */
template <
    typename CTYPE_COMMON,
    const char* op_name,
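
Editor's note: to make conditions (1) and (2) concrete, here is a small self-contained sketch. `FakeVec` is a stand-in for at::vec::Vectorized<float> (not an ATen type) and the lambdas are illustrative; the point is that invocability checks only see a generic lambda's signature, so a body that cannot compile for the vector type must surface that requirement in the signature (e.g., via a trailing return type) to fail cleanly.

```cpp
#include <cmath>
#include <type_traits>

// Stand-in for at::vec::Vectorized<float>: supports operator+ and nothing else.
struct FakeVec {
  float vals[4];
  FakeVec operator+(const FakeVec& o) const {
    FakeVec r;
    for (int i = 0; i < 4; ++i) {
      r.vals[i] = vals[i] + o.vals[i];
    }
    return r;
  }
};

// Condition 1: the body happens to compile for both float and FakeVec.
auto add_self = [](const auto& x) { return x + x; };
static_assert(std::is_invocable_v<decltype(add_self), float>);
static_assert(std::is_invocable_v<decltype(add_self), FakeVec>);

// Condition 2: std::exp(FakeVec) does not exist, so the requirement is stated
// in the signature via a trailing return type. The failure then occurs in the
// immediate context, and is_invocable_v is cleanly false for FakeVec instead
// of producing a hard error inside the lambda body.
auto exp_op = [](const auto& x) -> decltype(std::exp(x)) { return std::exp(x); };
static_assert(std::is_invocable_v<decltype(exp_op), float>);
static_assert(!std::is_invocable_v<decltype(exp_op), FakeVec>);

int main() {}
```
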
@@ -296,6 +382,7 @@ inline void apply_bitensor_elementwise_fn(
 * Useful for bi-tensor elementwise operators. For each element of the inputs,
 * perform a computation and write to the corresponding element of the output.
 * Tensor broadcasting is applied wherever it is required.
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for compute_fun.
 */
template <
    typename CTYPE_COMMON,
@@ -362,6 +449,8 @@ inline void apply_tritensor_elementwise_fn(
 *
 * static constexpr const char op_name[] = "my_op";
 * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
+ *
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for compute_fun.
 */
template <
    typename CTYPE_COMMON,