diff --git a/kernels/optimized/cpu/op_div.cpp b/kernels/optimized/cpu/op_div.cpp
index 4d7b8efe9e..e630f1c03b 100644
--- a/kernels/optimized/cpu/op_div.cpp
+++ b/kernels/optimized/cpu/op_div.cpp
@@ -120,46 +120,22 @@ Tensor& opt_div_out(
           out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
-    if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-      lhs = &b;
-      rhs = &a;
-    } else {
-      // Catch failure to update logic when subing new broadcasting possibility.
-      ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
-      lhs = &a;
-      rhs = &b;
-    }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "sub.out", CTYPE, [&]() {
-      using Vec = executorch::vec::Vectorized<CTYPE>;
+    // Reason for using alpha is because handle_broadcast_elementwise
+    // is used for add and sub as well:
+    ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE, [&]() {
       if (selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [](Vec x, Vec y) { return y / x; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
+              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+          selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+          selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+        auto div_lambda = [](auto x, auto y) { return y / x; };
+        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+            ctx, div_lambda, a, b, out, selected_optimized_path);
       } else {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [](Vec x, Vec y) { return x / y; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
+        auto div_lambda = [](auto x, auto y) { return x / y; };
+        return torch::executor::handle_broadcast_elementwise<CTYPE>(
+            ctx, div_lambda, a, b, out, selected_optimized_path);
       }
     });
   } else {
diff --git a/kernels/test/op_div_test.cpp b/kernels/test/op_div_test.cpp
index 97d538971c..8f41419a8e 100644
--- a/kernels/test/op_div_test.cpp
+++ b/kernels/test/op_div_test.cpp
@@ -83,6 +83,52 @@ class OpDivOutTest : public OperatorTest {
     ET_EXPECT_KERNEL_FAILURE(context_, op_div_out(a, b, out));
   }
 
+  template <ScalarType DTYPE>
+  void test_broadcast_3D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
+
+    // Destination for output of div.
+    Tensor out =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor expected = tf_a.make(
+        {2, 2, 3},
+        /*data=*/
+        {0.5000,
+         0.6667,
+         0.7500,
+         2.0000,
+         1.6667,
+         1.5000,
+         1.4000,
+         1.3333,
+         1.2857,
+         2.0000,
+         1.8333,
+         1.7143});
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE_WITH_TOL(op_div_out(a, b, out), expected, 1e-4, 1e-4);
+    expected = tf_a.make(
+        {2, 2, 3},
+        /*data=*/
+        {2.0000,
+         1.5000,
+         1.3333,
+         0.5000,
+         0.6000,
+         0.6667,
+         0.7143,
+         0.7500,
+         0.7778,
+         0.5000,
+         0.5455,
+         0.5833});
+    EXPECT_TENSOR_CLOSE_WITH_TOL(op_div_out(b, a, out), expected, 1e-4, 1e-4);
+  }
+
   /**
    * Common testing for div operator, for float output types
    */
@@ -457,6 +503,14 @@ TEST_F(OpDivOutTest, DynamicShapeUpperBoundLargerThanExpected) {
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
 
+TEST_F(OpDivOutTest, BroadcastNDTest) {
+  // Test 3D tensors
+  test_broadcast_3D<ScalarType::Float>();
+  // half and bfloat16 are not supported for div quite yet
+  // test_broadcast_3D<ScalarType::Half>();
+  // test_broadcast_3D<ScalarType::BFloat16>();
+}
+
 TEST_F(OpDivOutTest, DynamicShapeUnbound) {
   GTEST_SKIP() << "Dynamic shape not supported";
   TensorFactory<ScalarType::Float> tf;