Skip to content

Commit

Permalink
Use tensor_shape_to_c_string for error in check_mask_indices
Browse files Browse the repository at this point in the history
Rolling out for #7902

ghstack-source-id: db14c70dfafc167611004d7842325d3f6e0ebbde
ghstack-comment-id: 2643854240
Pull Request resolved: #8314
  • Loading branch information
swolchok committed Feb 7, 2025
1 parent 883d33a commit e96a6b7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
20 changes: 17 additions & 3 deletions kernels/portable/cpu/util/advanced_index_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/

#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_shape_to_c_string.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
Expand Down Expand Up @@ -49,9 +50,22 @@ bool check_mask_indices(const Tensor& in, TensorOptList indices) {
ET_LOG_MSG_AND_RETURN_IF_FALSE(
index.dim() > 0, "Zero-dimensional mask index not allowed");
for (auto j = 0; j < index.dim(); j++) {
ET_LOG_MSG_AND_RETURN_IF_FALSE(
index.size(j) == in.size(in_i + j),
"The shape of mask index must match the sizes of the corresponding input dimensions.");
if (index.size(j) != in.size(in_i + j)) {
#ifdef ET_LOG_ENABLED
auto mask_shape = executorch::runtime::tensor_shape_to_c_string(
executorch::runtime::Span<const Tensor::SizesType>(
index.sizes().data(), index.sizes().size()));
auto input_shape = executorch::runtime::tensor_shape_to_c_string(
executorch::runtime::Span<const Tensor::SizesType>(
in.sizes().data() + in_i, index.sizes().size()));
ET_LOG(
Error,
"The shape of mask index %s must match the sizes of the corresponding input dimensions %s.",
mask_shape.data(),
input_shape.data());
#endif // ET_LOG_ENABLED
return false;
}
}
in_i += index.dim();
} else {
Expand Down
1 change: 1 addition & 0 deletions kernels/portable/cpu/util/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def define_common_targets():
compiler_flags = ["-Wno-missing-prototypes"],
deps = [
":broadcast_util",
"//executorch/runtime/core/exec_aten/util:tensor_shape_to_c_string",
"//executorch/runtime/kernel:kernel_includes",
],
visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
Expand Down

0 comments on commit e96a6b7

Please sign in to comment.