Skip to content

Commit

Permalink
[Executorch] Add broadcasting support to optimized op_sub
Browse files Browse the repository at this point in the history
Summary:
This diff builds on top of the previous one to add limited broadcasting
support to the optimized sub kernel.

Test Plan:
tests added

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 25520a88d2b79278395c105183b70341bc004e5e
Pull Request resolved: #8256
  • Loading branch information
kimishpatel committed Feb 7, 2025
1 parent a499572 commit 22178b2
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 104 deletions.
109 changes: 5 additions & 104 deletions kernels/optimized/cpu/op_sub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

#include <executorch/kernels/optimized/cpu/op_add_sub_impl.h>

namespace torch {
namespace executor {
namespace native {
Expand Down Expand Up @@ -138,110 +140,9 @@ Tensor& opt_sub_out(
}
}

auto selected_optimized_path = select_optimized_path(a, b, out);
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
// Resize for dynamic shape
auto error = resize_tensor(out, a.sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");

ET_SWITCH_REAL_TYPES(a_type, ctx, "sub.out", CTYPE, [&]() {
CTYPE alpha_val;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::map2<CTYPE>(
[alpha_val](Vec x, Vec y) { return x - Vec(alpha_val) * y; },
out.mutable_data_ptr<CTYPE>(),
a.const_data_ptr<CTYPE>(),
b.const_data_ptr<CTYPE>(),
out.numel());
});
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
const Tensor* lhs;
const Tensor* rhs;
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
lhs = &b;
rhs = &a;
} else {
// Catch failure to update logic when subing new broadcasting possibility.
ET_DCHECK(
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1d);
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
ET_SWITCH_REAL_TYPES(out_type, ctx, "sub.out", CTYPE, [&]() {
CTYPE alpha_val;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

using Vec = executorch::vec::Vectorized<CTYPE>;
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
[alpha_val](Vec x, Vec y) { return y - Vec(alpha_val) * x; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
lhs->sizes()[lhs->dim() - 2],
lhs->sizes()[lhs->dim() - 1]);
} else {
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
[alpha_val](Vec x, Vec y) { return x - Vec(alpha_val) * y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
lhs->sizes()[lhs->dim() - 2],
lhs->sizes()[lhs->dim() - 1]);
}
});
} else {
ScalarType common_type =
promoteTypes(a_type, b_type, /*half_to_float*/ true);
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
InvalidArgument,
out);

ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.out", CTYPE_A, [&]() {
ET_SWITCH_REALH_TYPES(b_type, ctx, "sub.out", CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
CTYPE_IN alpha_val;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

SubInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, alpha_val, out);
});
});
});
}

return out;
static constexpr const char op_name[] = "sub.out";
return torch::executor::kernels::impl::opt_add_sub_out_impl<true, op_name>(
ctx, a, b, alpha, out);
}

Tensor& opt_sub_scalar_out(
Expand Down
1 change: 1 addition & 0 deletions kernels/optimized/cpu/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ _OPTIMIZED_ATEN_OPS = (
name = "op_sub",
deps = [
":binary_ops",
":add_sub_impl",
"//executorch/kernels/portable/cpu:scalar_utils",
"//executorch/kernels/portable/cpu/util:broadcast_util",
],
Expand Down
116 changes: 116 additions & 0 deletions kernels/test/op_sub_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,109 @@ class OpSubOutTest : public OperatorTest {
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{0.1, 1.2, 3.4, 7.8}));
}

template <ScalarType DTYPE>
void test_broadcast_3D() {
TensorFactory<DTYPE> tf_a;

Tensor a =
tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});

// Destination for output of mul.
Tensor out =
tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
Tensor expected =
tf_a.make({2, 2, 3}, /*data=*/{-1, -1, -1, 2, 2, 2, 2, 2, 2, 5, 5, 5});

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(op_sub_out(a, b, 1.0, out), expected);
// b - a * 1.5 output should be
expected = tf_a.make(
{2, 2, 3},
/*data=*/
{0.5,
0.0,
-0.5,
-4.0,
-4.5,
-5.0,
-5.5,
-6.0,
-6.5,
-10.0,
-10.5,
-11.0});
EXPECT_TENSOR_CLOSE(op_sub_out(b, a, 1.5, out), expected);
}

// Exercises limited broadcasting in the optimized sub kernel on 4-D
// tensors: first b of shape {2, 1, 3, 5} (broadcast over dim 1), then b of
// shape {2, 2, 1, 5} (broadcast over dim 2), in both argument orders and
// with a non-unit alpha.
template <ScalarType DTYPE>
void test_broadcast_4D() {
  TensorFactory<DTYPE> tf_a;

  // a holds 1..60 in row-major order over shape {2, 2, 3, 5}.
  Tensor a = tf_a.make(
      {2, 2, 3, 5},
      /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
                16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
                31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
                46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
  // b broadcasts over dim 1.
  Tensor b = tf_a.make(
      {2, 1, 3, 5},
      /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
                16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});

  // Destination for output of sub.
  Tensor out = tf_a.zeros({2, 2, 3, 5});
  // a - 1.0 * b with b broadcast along dim 1.
  Tensor expected = tf_a.make(
      {2, 2, 3, 5},
      /*data=*/{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
                15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
                30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30});

  // Check that it matches the expected output.
  EXPECT_TENSOR_CLOSE(op_sub_out(a, b, 1.0, out), expected);
  // Swapping the arguments (b - 1.0 * a) negates the previous result.
  expected = tf_a.make(
      {2, 2, 3, 5},
      /*data=*/{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, -15, -15, -15, -15, -15, -15, -15, -15, -15,
                -15, -15, -15, -15, -15, -15, -15, -15, -15, -15, -15, -15,
                -15, -15, -15, -15, -15, -15, -15, -15, -15, -30, -30, -30,
                -30, -30, -30, -30, -30, -30, -30, -30, -30, -30, -30, -30});
  EXPECT_TENSOR_CLOSE(op_sub_out(b, a, 1.0, out), expected);

  // Second case: b now broadcasts over dim 2 instead.
  b = tf_a.make(
      {2, 2, 1, 5}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
                              11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
  out = tf_a.zeros({2, 2, 3, 5});
  // a - 1.0 * b with b broadcast along dim 2.
  expected = tf_a.make(
      {2, 2, 3, 5},
      /*data=*/{0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 10, 10, 10, 10, 10,
                10, 10, 10, 10, 10, 15, 15, 15, 15, 15, 20, 20, 20, 20, 20,
                20, 20, 20, 20, 20, 25, 25, 25, 25, 25, 30, 30, 30, 30, 30,
                30, 30, 30, 30, 30, 35, 35, 35, 35, 35, 40, 40, 40, 40, 40});

  // Check that it matches the expected output.
  EXPECT_TENSOR_CLOSE(op_sub_out(a, b, 1.0, out), expected);
  // b - 1.5 * a: reversed arguments combined with a non-unit alpha.
  expected = tf_a.make(
      {2, 2, 3, 5},
      /*data=*/{-0.5000, -1.0000, -1.5000, -2.0000, -2.5000,
                -8.0000, -8.5000, -9.0000, -9.5000, -10.0000,
                -15.5000, -16.0000, -16.5000, -17.0000, -17.5000,

                -18.0000, -18.5000, -19.0000, -19.5000, -20.0000,
                -25.5000, -26.0000, -26.5000, -27.0000, -27.5000,
                -33.0000, -33.5000, -34.0000, -34.5000, -35.0000,

                -35.5000, -36.0000, -36.5000, -37.0000, -37.5000,
                -43.0000, -43.5000, -44.0000, -44.5000, -45.0000,
                -50.5000, -51.0000, -51.5000, -52.0000, -52.5000,

                -53.0000, -53.5000, -54.0000, -54.5000, -55.0000,
                -60.5000, -61.0000, -61.5000, -62.0000, -62.5000,
                -68.0000, -68.5000, -69.0000, -69.5000, -70.0000});
  EXPECT_TENSOR_CLOSE(op_sub_out(b, a, 1.5, out), expected);
}

void test_sub_enumerate_a_types() {
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
test_sub_enumerate_b_types<ScalarType::dtype>();
Expand Down Expand Up @@ -237,6 +340,19 @@ TEST_F(OpSubOutTest, BroadcastScalarRank0Supported) {
EXPECT_TENSOR_EQ(out, ret);
}

TEST_F(OpSubOutTest, BroadcastNDTest) {
  // Exercise limited broadcasting on 3-D inputs.
  test_broadcast_3D<ScalarType::Float>();
  test_broadcast_3D<ScalarType::Half>();

  // Exercise limited broadcasting on 4-D inputs.
  test_broadcast_4D<ScalarType::Float>();
  test_broadcast_4D<ScalarType::Half>();

  // TODO: enable once optimized sub supports BFloat16.
  // test_broadcast_3D<ScalarType::BFloat16>();
  // test_broadcast_4D<ScalarType::BFloat16>();
}

//
// Death Tests
//
Expand Down

0 comments on commit 22178b2

Please sign in to comment.