15
15
#include < executorch/runtime/kernel/kernel_runtime_context.h>
16
16
#include < executorch/runtime/kernel/thread_parallel_interface.h>
17
17
18
+ #ifdef ET_USE_PYTORCH_HEADERS
19
+ #include < ATen/cpu/vec/vec.h>
20
+ #endif // ET_USE_PYTORCH_HEADERS
21
+
18
22
#include < array>
19
23
#include < utility>
20
24
@@ -58,6 +62,38 @@ template <typename CTYPE_COMMON, typename Op, typename... Args>
58
62
using op_call_result =
59
63
std::invoke_result_t <Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;
60
64
65
#ifdef ET_USE_PYTORCH_HEADERS
// Trait that reports whether T is an at::vec::Vectorized specialization.
template <typename T>
struct is_vectorized : public std::false_type {};

template <typename T>
struct is_vectorized<at::vec::Vectorized<T>> : public std::true_type {};

// TODO: can_use_vectorized and can_use_vectorized_impl are a failed
// attempt to use SFINAE to detect whether our generic lambda argument
// with deduced return type would compile if it was passed
// Vectorized<CTYPE_COMMON> instead of CTYPE_COMMON. SFINAE does not
// work that way (see
// e.g. https://stackoverflow.com/questions/53344484/hard-error-when-using-stdinvoke-result-t-with-a-generic-lambda,
// https://stackoverflow.com/questions/31368601/how-to-detect-if-a-generic-lambda-is-uncompilable-in-c-14);
// if we really want to do it then we need to at least require that
// our lambdas actively participate in being SFINAE-friendly, as in
// https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable.

// Primary template: assume the op cannot be called with Vectorized
// arguments unless the specialization below proves otherwise.
template <
    typename CTYPE_COMMON,
    typename Op,
    typename Enable = void,
    typename... Args>
struct can_use_vectorized_impl : std::false_type {};

// Specialization selected (via std::void_t SFINAE) when invoking Op
// with sizeof...(Args) arguments of type at::vec::Vectorized<CTYPE_COMMON>
// yields a result that supports .store(CTYPE_COMMON*), i.e. a Vectorized
// value we can write back to memory.
template <typename CTYPE_COMMON, typename Op, typename... Args>
struct can_use_vectorized_impl<
    CTYPE_COMMON,
    Op,
    std::void_t<decltype(std::declval<std::invoke_result_t<
                             Op,
                             ignore_first_yield_second<
                                 Args,
                                 at::vec::Vectorized<CTYPE_COMMON>>...>>()
                             .store(std::declval<CTYPE_COMMON*>()))>,
    Args...> : public std::true_type {};

// Can I call a function of type Op with sizeof...(Args) arguments of type
// at::vec::Vectorized<CTYPE_COMMON>?
// This is not possible in C++17 as the code is currently set up; see TODO
// above.
template <typename CTYPE_COMMON, typename Op, typename... Args>
struct can_use_vectorized
    : public can_use_vectorized_impl<CTYPE_COMMON, Op, void, Args...> {};

#endif // ET_USE_PYTORCH_HEADERS
96
+
61
97
template <
62
98
typename CTYPE_COMMON,
63
99
typename CTYPE_OUT,
@@ -68,14 +104,72 @@ inline void dtype_specialized_elementwise_fn_impl(
68
104
KernelRuntimeContext& ctx,
69
105
const Tensor& out,
70
106
Args... inputs) {
107
+ static_assert (
108
+ (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
109
+ ...));
71
110
constexpr auto kNumInputs = sizeof ...(inputs);
72
- ET_DCHECK (((inputs.first ->element_size () == sizeof (CTYPE_COMMON)) && ...));
111
+ // All inputs must be of type CTYPE_COMMON.
112
+ ET_DCHECK (
113
+ ((inputs.first ->scalar_type () ==
114
+ CppTypeToScalarType<CTYPE_COMMON>::value) &&
115
+ ...));
73
116
74
117
std::array<const CTYPE_COMMON*, kNumInputs > inputs_data_ptrs = {
75
118
inputs.first ->template const_data_ptr <CTYPE_COMMON>()...};
76
119
77
120
CTYPE_OUT* const data_out = out.mutable_data_ptr <CTYPE_OUT>();
78
121
122
+ #ifdef ET_USE_PYTORCH_HEADERS
123
+ if constexpr (can_use_vectorized<CTYPE_COMMON, Op, Args...>::value) {
124
+ const bool any_is_broadcasted =
125
+ !(torch::executor::internal::sizes_match_ignoring_leading_1s (
126
+ inputs.first ->sizes (), out.sizes ()) &&
127
+ ...);
128
+ if (!any_is_broadcasted) {
129
+ using Vec = at::vec::Vectorized<CTYPE_COMMON>;
130
+ ::executorch::extension::parallel_for (
131
+ 0 ,
132
+ out.numel(),
133
+ ::executorch::extension::internal::GRAIN_SIZE,
134
+ [&](const auto begin, const auto end) {
135
+ const auto vectorized_begin =
136
+ begin + (Vec::size () - begin % Vec::size ()) % Vec::size ();
137
+ const auto vectorized_end = end - (end % Vec::size ());
138
+ // Scalar prologue.
139
+ for (const auto idx : c10::irange (begin, vectorized_begin)) {
140
+ std::array<CTYPE_COMMON, kNumInputs > loaded_inputs;
141
+ for (const auto input_idx : c10::irange (kNumInputs )) {
142
+ loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
143
+ }
144
+ data_out[idx] = std::apply (compute_fun, loaded_inputs);
145
+ }
146
+
147
+ // Main vectorized loop.
148
+ for (auto idx = vectorized_begin; idx < vectorized_end;
149
+ idx += Vec::size ()) {
150
+ std::array<Vec, kNumInputs > loaded_vec_inputs;
151
+ for (const auto input_idx : c10::irange (kNumInputs )) {
152
+ loaded_vec_inputs[input_idx] =
153
+ Vec::loadu (&inputs_data_ptrs[input_idx][idx]);
154
+ }
155
+ auto result_vec = std::apply (compute_fun, loaded_vec_inputs);
156
+ result_vec.store (&data_out[idx]);
157
+ }
158
+
159
+ // Scalar epilogue.
160
+ for (const auto idx : c10::irange (vectorized_end, end)) {
161
+ std::array<CTYPE_COMMON, kNumInputs > loaded_inputs;
162
+ for (const auto input_idx : c10::irange (kNumInputs )) {
163
+ loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
164
+ }
165
+ data_out[idx] = std::apply (compute_fun, loaded_inputs);
166
+ }
167
+ });
168
+ return ;
169
+ }
170
+ }
171
+ #endif
172
+
79
173
::executorch::extension::parallel_for (
80
174
0 ,
81
175
out.numel(),
0 commit comments