From 9fcd8857fb0e00bee0b401f5e25f1fd081fe3c9c Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 18 Mar 2025 17:32:12 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- kernels/portable/cpu/util/dtype_util.h | 11 ---------- kernels/portable/cpu/util/elementwise_util.h | 23 ++++++++++++++++---- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h index 2bbd5de4577..59b82cdc51b 100644 --- a/kernels/portable/cpu/util/dtype_util.h +++ b/kernels/portable/cpu/util/dtype_util.h @@ -86,12 +86,6 @@ load_to_common_fn get_load_to_common_fn_bool_or_byte( template load_to_common_fn get_load_to_common_fn_same_as_compute( const Tensor& t) { - constexpr auto common_scalar_type = CppTypeToScalarType::value; - ET_CHECK_MSG( - t.scalar_type() == common_scalar_type, - "Unhandled dtype %s for %s", - ::executorch::runtime::toString(common_scalar_type), - op_name); return internal::load_and_convert; } @@ -180,11 +174,6 @@ template store_common_to_tensor_fn get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) { constexpr auto common_scalar_type = CppTypeToScalarType::value; - ET_CHECK_MSG( - t.scalar_type() == common_scalar_type, - "Unhandled dtype %s for %s", - ::executorch::runtime::toString(common_scalar_type), - op_name); return internal::convert_and_store; } diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index f5932069005..021ec42bf27 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -51,6 +51,13 @@ inline int64_t scalar_to(const Scalar& s) { } namespace internal { +template +using ignore_first_yield_second = T; + +template +using op_call_result = + std::invoke_result_t...>; + template < typename CTYPE_COMMON, const char* op_name, @@ -89,9 +96,16 @@ inline void apply_elementwise_fn( inputs.first->element_size(), })...}; - const auto store_common_to_out = - internal::get_store_common_to_tensor_fn( - out, out_dtypes); + // NOTE: the result of compute_fun is not necessarily CTYPE_COMMON! + // For example, consider the possibility that compute_fun is a + // trigonometric function like acos, the common input type is bool, + // and the output type is float -- we would truncate acos(0) ~= 1.67 + // to just 1. Conveniently, it costs us nothing at runtime to handle + // this correctly. + const auto store_compute_result_to_out = + internal::get_store_common_to_tensor_fn< + op_call_result, + op_name>(out, out_dtypes); char* const data_out = reinterpret_cast(out.mutable_data_ptr()); const auto out_element_size = out.element_size(); @@ -114,7 +128,8 @@ inline void apply_elementwise_fn( .data_ptr[indexes[idx + 1] * input_info.element_size]); } auto result = std::apply(compute_fun, loaded_inputs); - store_common_to_out(result, &data_out[indexes[0] * out_element_size]); + store_compute_result_to_out( + result, &data_out[indexes[0] * out_element_size]); } }); } From 40c1b1be46d2ad91f6ca39fe3008d9b685d3f45b Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 19 Mar 2025 09:58:10 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- kernels/portable/cpu/util/dtype_util.h | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h index 76579301850..1f0e3403e82 100644 --- a/kernels/portable/cpu/util/dtype_util.h +++ b/kernels/portable/cpu/util/dtype_util.h @@ -173,27 +173,13 @@ get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) { template store_common_to_tensor_fn get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) { - return internal::convert_and_store; + // We already validate tensor types earlier in the process, so at + // this phase, treat same_as_compute the same as our widest + // SupportedTensorDtypes set. + return get_store_common_to_tensor_fn_realhbf16(t); } -template < - typename CTYPE_COMMON, - const char* op_name, - std::enable_if_t, bool> = true> -store_common_to_tensor_fn -get_store_common_to_tensor_fn_same_as_common(const Tensor& t) { - void (*result)(CTYPE_COMMON, void*) = nullptr; - ET_SWITCH_THREE_TYPES( - Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() { - result = internal::convert_and_store; - }); - return result; -} - -template < - typename CTYPE_COMMON, - const char* op_name, - std::enable_if_t, bool> = true> +template store_common_to_tensor_fn get_store_common_to_tensor_fn_same_as_common(const Tensor& t) { return get_store_common_to_tensor_fn_same_as_compute(