Add vectorization in elementwise_util (not working yet) #9432

Draft · wants to merge 2 commits into base: gh/swolchok/385/head

Changes from all commits
2 changes: 2 additions & 0 deletions .lintrunner.toml
@@ -264,6 +264,8 @@ exclude_patterns = [
'examples/**',
'exir/verification/bindings.cpp',
'extension/**',
# Uses properly-gated (ET_USE_PYTORCH_HEADERS) ATen include.
'kernels/portable/cpu/util/elementwise_util.h',
'kernels/optimized/**',
'runtime/core/exec_aten/**',
# Want to be able to keep c10 in sync with PyTorch core.
4 changes: 1 addition & 3 deletions kernels/portable/cpu/op_mul.cpp
@@ -56,9 +56,7 @@ Tensor& mul_out(
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return val_a * val_b;
},
[](const auto val_a, const auto val_b) { return val_a * val_b; },
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
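
The switch to a generic lambda in op_mul.cpp is what lets the new vectorized path in elementwise_util.h apply: the same lambda body has to compile both for the scalar compute type and for at::vec::Vectorized of that type. A minimal standalone sketch of that idea (illustrative only, not part of this diff; assumes the ATen vec headers are available):

```cpp
#include <ATen/cpu/vec/vec.h>

// Sketch only: this generic lambda is invocable with either a plain float or
// an at::vec::Vectorized<float>, because operator* is defined for both. That
// dual compilability is what can_use_vectorized() (added in elementwise_util.h
// below) detects before taking the SIMD fast path.
inline float scalar_example(float a, float b) {
  auto mul = [](const auto val_a, const auto val_b) { return val_a * val_b; };
  return mul(a, b);  // scalar instantiation
}

inline at::vec::Vectorized<float> vector_example(
    at::vec::Vectorized<float> a,
    at::vec::Vectorized<float> b) {
  auto mul = [](const auto val_a, const auto val_b) { return val_a * val_b; };
  return mul(a, b);  // Vectorized instantiation of the same lambda
}
```
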
15 changes: 6 additions & 9 deletions kernels/portable/cpu/pattern/pattern.h
@@ -80,13 +80,12 @@ Tensor& unary_ufunc_realh(
ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out);

ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
utils::apply_unitensor_elementwise_fn<CTYPE, op_name, utils::SupportedTensorDtypes::SAME_AS_COMMON>(
fn,
ctx,
in,
utils::SupportedTensorDtypes::REALH,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
out);
});
return out;
}
@@ -107,13 +106,12 @@ Tensor& unary_ufunc_realhb_to_bool(
return out;
}
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name>(
utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name, utils::SupportedTensorDtypes::BOOL>(
[fn](const CTYPE_IN val_in) { return fn(val_in); },
ctx,
in,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::BOOL);
out);
});

return out;
@@ -138,13 +136,12 @@ Tensor& unary_ufunc_realhbbf16_to_floathbf16(
}

ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name>(
utils::apply_unitensor_elementwise_fn<CTYPE_IN, op_name, utils::SupportedTensorDtypes::FLOATHBF16>(
[fn](const CTYPE_IN val_in) { return fn(val_in); },
ctx,
in,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::FLOATHBF16);
out);
});

return out;
91 changes: 90 additions & 1 deletion kernels/portable/cpu/util/elementwise_util.h
@@ -15,6 +15,10 @@
#include <executorch/runtime/kernel/kernel_runtime_context.h>
#include <executorch/runtime/kernel/thread_parallel_interface.h>

#ifdef ET_USE_PYTORCH_HEADERS
#include <ATen/cpu/vec/vec.h>
#endif // ET_USE_PYTORCH_HEADERS

#include <array>
#include <utility>

@@ -58,6 +62,19 @@ template <typename CTYPE_COMMON, typename Op, typename... Args>
using op_call_result =
std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;

#ifdef ET_USE_PYTORCH_HEADERS
// Can I call a function of type Op with sizeof...(Args) arguments of type
// at::vec::Vectorized<CTYPE_COMMON>?
//
// See [NOTE: Generic lambdas] below for requirements on Op.
template <typename CTYPE_COMMON, typename Op, typename... Args>
constexpr bool can_use_vectorized() {
return std::is_invocable_v<
Op,
ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>;
}
#endif // ET_USE_PYTORCH_HEADERS

template <
typename CTYPE_COMMON,
typename CTYPE_OUT,
@@ -68,14 +85,72 @@ inline void dtype_specialized_elementwise_fn_impl(
KernelRuntimeContext& ctx,
const Tensor& out,
Args... inputs) {
static_assert(
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
...));
constexpr auto kNumInputs = sizeof...(inputs);
ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMMON)) && ...));
// All inputs must be of type CTYPE_COMMON.
ET_DCHECK(
((inputs.first->scalar_type() ==
CppTypeToScalarType<CTYPE_COMMON>::value) &&
...));

std::array<const CTYPE_COMMON*, kNumInputs> inputs_data_ptrs = {
inputs.first->template const_data_ptr<CTYPE_COMMON>()...};

CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

#ifdef ET_USE_PYTORCH_HEADERS
if constexpr (can_use_vectorized<CTYPE_COMMON, Op, Args...>()) {
const bool any_is_broadcasted =
!(torch::executor::internal::sizes_match_ignoring_leading_1s(
inputs.first->sizes(), out.sizes()) &&
...);
if (!any_is_broadcasted) {
using Vec = at::vec::Vectorized<CTYPE_COMMON>;
::executorch::extension::parallel_for(
0,
out.numel(),
::executorch::extension::internal::GRAIN_SIZE,
[&](const auto begin, const auto end) {
const auto vectorized_begin =
begin + (Vec::size() - begin % Vec::size()) % Vec::size();
const auto vectorized_end = end - (end % Vec::size());
// Scalar prologue.
for (const auto idx : c10::irange(begin, vectorized_begin)) {
std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
for (const auto input_idx : c10::irange(kNumInputs)) {
loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
}
data_out[idx] = std::apply(compute_fun, loaded_inputs);
}

// Main vectorized loop.
for (auto idx = vectorized_begin; idx < vectorized_end;
idx += Vec::size()) {
std::array<Vec, kNumInputs> loaded_vec_inputs;
for (const auto input_idx : c10::irange(kNumInputs)) {
loaded_vec_inputs[input_idx] =
Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
}
auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
result_vec.store(&data_out[idx]);
}

// Scalar epilogue.
for (const auto idx : c10::irange(vectorized_end, end)) {
std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
for (const auto input_idx : c10::irange(kNumInputs)) {
loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
}
data_out[idx] = std::apply(compute_fun, loaded_inputs);
}
});
return;
}
}
#endif // ET_USE_PYTORCH_HEADERS

::executorch::extension::parallel_for(
0,
out.numel(),
@@ -255,6 +330,17 @@ inline void apply_unitensor_elementwise_fn(
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
}

/**
* Useful for unary elementwise operators. For each element of the
* input, call Op and write to the corresponding element of the
* output. Tensor broadcasting is applied wherever it is required.
*
* [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
* parameters; normal lambdas are fine), it must satisfy one of the
* following conditions:
* 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMMON>, or
* 2) It must be actively SFINAE-friendly, as per the C++17 examples in https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable .
*/
template <
typename CTYPE_COMMON,
const char* op_name,
@@ -296,6 +382,7 @@ inline void apply_bitensor_elementwise_fn(
* Useful for bi-tensor elementwise operators. For each element of the inputs,
* perform a computation and write to the corresponding element of the output.
* Tensor broadcasting is applied wherever it is required.
* See [NOTE: Generic lambdas] if you want to pass a generic lambda for compute_fun.
*/
template <
typename CTYPE_COMMON,
@@ -362,6 +449,8 @@ inline void apply_tritensor_elementwise_fn(
*
* static constexpr const char op_name[] = "my_op";
* apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
*
* See [NOTE: Generic lambdas] if you want to pass a generic lambda for compute_fun.
*/
template <
typename CTYPE_COMMON,
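
To make condition 2 of [NOTE: Generic lambdas] concrete, here is a hedged sketch (illustrative only, not part of this diff) of an actively SFINAE-friendly generic lambda. Because the constraint lives in the declaration (the trailing return type), substitution failure for at::vec::Vectorized<float> is a soft failure that std::is_invocable_v, and therefore can_use_vectorized(), can observe, so the scalar path is chosen instead of a hard compile error inside the lambda body:

```cpp
#include <cmath>
#include <type_traits>

// Sketch only: a generic lambda that opts out of vectorization cleanly.
inline const auto scalar_only_erf = [](const auto x)
    -> std::enable_if_t<
        std::is_arithmetic_v<std::decay_t<decltype(x)>>,
        std::decay_t<decltype(x)>> {
  // Assume no Vectorized-friendly implementation of this math exists.
  return std::erf(x);
};

// Invocable with a scalar...
static_assert(std::is_invocable_v<decltype(scalar_only_erf), float>);
// ...while the analogous check for at::vec::Vectorized<float> (with ATen
// headers available) would be false, so can_use_vectorized() returns false
// and the existing scalar loop is used.
```
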
1 change: 1 addition & 0 deletions kernels/portable/cpu/util/targets.bzl
@@ -110,6 +110,7 @@ def define_common_targets():
":broadcast_indexes_range",
":broadcast_util",
":dtype_util",
"//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
"//executorch/runtime/kernel:kernel_runtime_context",
"//executorch/runtime/kernel:thread_parallel_interface",
],
5 changes: 4 additions & 1 deletion runtime/core/portable_type/c10/c10/targets.bzl
@@ -49,7 +49,10 @@ def define_common_targets():
runtime.cxx_library(
name = "aten_headers_for_executorch",
srcs = [],
visibility = ["//executorch/kernels/optimized/..."],
visibility = [
"//executorch/kernels/optimized/...",
"//executorch/kernels/portable/cpu/util/...",
],
exported_deps = select({
"DEFAULT": [],
"ovr_config//cpu:arm64": [