Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

elementwise_util: don't cast the result of compute_fun back to the common type #9385

Open
wants to merge 4 commits into
base: gh/swolchok/379/head
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 5 additions & 31 deletions kernels/portable/cpu/util/dtype_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,6 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
template <typename CTYPE_COMMON, const char* op_name>
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_compute(
const Tensor& t) {
constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
ET_CHECK_MSG(
t.scalar_type() == common_scalar_type,
"Unhandled dtype %s for %s",
::executorch::runtime::toString(common_scalar_type),
op_name);
return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
}

Expand Down Expand Up @@ -179,33 +173,13 @@ get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
template <typename CTYPE_COMMON, const char* op_name>
store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) {
constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
ET_CHECK_MSG(
t.scalar_type() == common_scalar_type,
"Unhandled dtype %s for %s",
::executorch::runtime::toString(common_scalar_type),
op_name);
return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
// We already validate tensor types earlier in the process, so at
// this phase, treat same_as_compute the same as our widest
// SupportedTensorDtypes set.
return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
}

template <
typename CTYPE_COMMON,
const char* op_name,
std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
void (*result)(CTYPE_COMMON, void*) = nullptr;
ET_SWITCH_THREE_TYPES(
Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() {
result = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
});
return result;
}

template <
typename CTYPE_COMMON,
const char* op_name,
std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
template <typename CTYPE_COMMON, const char* op_name>
store_common_to_tensor_fn<CTYPE_COMMON>
get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON, op_name>(
Expand Down
23 changes: 19 additions & 4 deletions kernels/portable/cpu/util/elementwise_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
}

namespace internal {
template <typename Ignore, typename T>
using ignore_first_yield_second = T;

template <typename CTYPE_COMMON, typename Op, typename... Args>
using op_call_result =
std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;

template <
typename CTYPE_COMMON,
const char* op_name,
Expand Down Expand Up @@ -89,9 +96,16 @@ inline void apply_elementwise_fn(
inputs.first->element_size(),
})...};

const auto store_common_to_out =
internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
out, out_dtypes);
// NOTE: the result of compute_fun is not necessarily CTYPE_COMMON!
// For example, consider the possibility that compute_fun is a
// trigonometric function like acos, the common input type is bool,
// and the output type is float -- we would truncate acos(0) ~= 1.57
// to just 1. Conveniently, it costs us nothing at runtime to handle
// this correctly.
const auto store_compute_result_to_out =
internal::get_store_common_to_tensor_fn<
op_call_result<CTYPE_COMMON, Op, Args...>,
op_name>(out, out_dtypes);
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
const auto out_element_size = out.element_size();

Expand All @@ -114,7 +128,8 @@ inline void apply_elementwise_fn(
.data_ptr[indexes[idx + 1] * input_info.element_size]);
}
auto result = std::apply(compute_fun, loaded_inputs);
store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
store_compute_result_to_out(
result, &data_out[indexes[0] * out_element_size]);
}
});
}
Expand Down
Loading