Skip to content

Commit

Permalink
[ExecuTorch] Add broadcasting support to optimized op_div
Browse files Browse the repository at this point in the history
Summary:
Similar to broadcast support in op_mul

Test Plan:
Tests added

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: c9d05c8405401e10a45a08c3e160aa04eaf1be86
Pull Request resolved: #8257
  • Loading branch information
kimishpatel committed Feb 7, 2025
1 parent 22178b2 commit 3262edc
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 39 deletions.
65 changes: 26 additions & 39 deletions kernels/optimized/cpu/op_div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,48 +120,35 @@ Tensor& opt_div_out(
out.numel());
});
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
const Tensor* lhs;
const Tensor* rhs;
// Reason for using alpha is because handle_broadcast_elementwise
// is used for add and sub as well:
static constexpr const char op_name[] = "mul.out";
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
lhs = &b;
rhs = &a;
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
// This behavior is a bit confusing.
// The reason we swap the args here is that handle_broadcast_elementwise
// handles this selected_optimized_path option a bit differently.
// This should really be resolved in handle_broadcast_elementwise.
// However, the current blocker is that handle_broadcast_elementwise tries
// to be agnostic of op. This should be fixed, likely by moving lambda
// creation into handle_broadcast_elementwise and making it aware of which
// op is being executed.
auto div_lambda = [](auto x, auto y, [[maybe_unused]] auto alpha) {
return y / x;
};
return torch::executor::handle_broadcast_elementwise<op_name>(
ctx, div_lambda, a, b, out, selected_optimized_path);
} else {
// Catch failure to update logic when adding a new broadcasting possibility.
ET_DCHECK(
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1d);
lhs = &a;
rhs = &b;
auto div_lambda = [](auto x, auto y, [[maybe_unused]] auto alpha) {
return x / y;
};
return torch::executor::handle_broadcast_elementwise<op_name>(
ctx, div_lambda, a, b, out, selected_optimized_path);
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
ET_SWITCH_REALB_TYPES(out_type, ctx, "sub.out", CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
[](Vec x, Vec y) { return y / x; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
lhs->sizes()[lhs->dim() - 2],
lhs->sizes()[lhs->dim() - 1]);
} else {
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
[](Vec x, Vec y) { return x / y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
lhs->sizes()[lhs->dim() - 2],
lhs->sizes()[lhs->dim() - 1]);
}
});
} else {
ScalarType common_type = get_compute_type(a_type, b_type);
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
Expand Down
54 changes: 54 additions & 0 deletions kernels/test/op_div_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,52 @@ class OpDivOutTest : public OperatorTest {
ET_EXPECT_KERNEL_FAILURE(context_, op_div_out(a, b, out));
}

/// Exercises Nd-by-Nd broadcasting in div: a {2, 2, 3} tensor divided by a
/// {2, 1, 3} tensor (broadcast along dim 1), in both argument orders,
/// checked against hand-computed quotients with a 1e-4 tolerance.
template <ScalarType DTYPE>
void test_broadcast_3D() {
  TensorFactory<DTYPE> tf_a;

  Tensor a =
      tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});

  // Destination for the output of div.
  Tensor out =
      tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  // Expected a / b: rows {1,2,3} and {4,5,6} are each divided elementwise by
  // {2,3,4}; rows {7,8,9} and {10,11,12} are each divided by {5,6,7}.
  Tensor expected = tf_a.make(
      {2, 2, 3},
      /*data=*/
      {0.5000,
       0.6667,
       0.7500, // 3 / 4 (was mistyped as 0.75002)
       2.0000,
       1.6667,
       1.5000,
       1.4000,
       1.3333,
       1.2857,
       2.0000,
       1.8333,
       1.7143});
  // Check that it matches the expected output.
  EXPECT_TENSOR_CLOSE_WITH_TOL(op_div_out(a, b, out), expected, 1e-4, 1e-4);
  // Expected b / a: same row pairing with the arguments swapped, exercising
  // the ReverseArguments broadcast path.
  expected = tf_a.make(
      {2, 2, 3},
      /*data=*/
      {2.0000,
       1.5000,
       1.3333,
       0.5000,
       0.6000,
       0.6667,
       0.7143,
       0.7500,
       0.7778,
       0.5000,
       0.5455,
       0.5833});
  EXPECT_TENSOR_CLOSE_WITH_TOL(op_div_out(b, a, out), expected, 1e-4, 1e-4);
}

/**
* Common testing for div operator, for float output types
*/
Expand Down Expand Up @@ -457,6 +503,14 @@ TEST_F(OpDivOutTest, DynamicShapeUpperBoundLargerThanExpected) {
EXPECT_TENSOR_CLOSE(out, expected_result);
}

// Entry point for the Nd-by-Nd broadcasting coverage of div added in this
// change; delegates to the templated helper above.
TEST_F(OpDivOutTest, BroadcastNDTest) {
  // Test 3D tensors
  test_broadcast_3D<ScalarType::Float>();
  // half and bfloat16 are not supported for div quite yet
  // TODO(review): enable once the optimized div kernel supports these dtypes.
  // test_broadcast_3D<ScalarType::Half>();
  // test_broadcast_3D<ScalarType::BFloat16>();
}

TEST_F(OpDivOutTest, DynamicShapeUnbound) {
GTEST_SKIP() << "Dynamic shape not supported";
TensorFactory<ScalarType::Float> tf;
Expand Down

0 comments on commit 3262edc

Please sign in to comment.