Skip to content

Commit b195ed9

Browse files
authored
Use parallel_for in functional_util's apply_unary_map_fun (#9348)
The other ones are reductions. More #8932 rollout.
1 parent a28e9be commit b195ed9

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

kernels/portable/cpu/util/functional_util.h

+10-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1414
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
15+
#include <executorch/runtime/kernel/thread_parallel_interface.h>
1516

1617
namespace torch {
1718
namespace executor {
@@ -53,9 +54,15 @@ inline void apply_unary_map_fn(
5354
CTYPE_OUT* const data_out,
5455
const int64_t size,
5556
const int64_t stride = 1) {
56-
for (const auto i : c10::irange(size)) {
57-
data_out[i * stride] = map_fun(data_in[i * stride]);
58-
}
57+
executorch::extension::parallel_for(
58+
0,
59+
size,
60+
::executorch::extension::internal::GRAIN_SIZE,
61+
[&](const auto begin, const auto end) {
62+
for (const auto i : c10::irange(begin, end)) {
63+
data_out[i * stride] = map_fun(data_in[i * stride]);
64+
}
65+
});
5966
}
6067

6168
//

kernels/portable/cpu/util/targets.bzl

+3
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ def define_common_targets():
244244
name = "functional_util",
245245
srcs = [],
246246
exported_headers = ["functional_util.h"],
247+
exported_deps = [
248+
"//executorch/runtime/kernel:thread_parallel_interface",
249+
],
247250
deps = [
248251
"//executorch/runtime/kernel:kernel_includes",
249252
"//executorch/runtime/core/exec_aten/util:tensor_util",

0 commit comments

Comments
 (0)