Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use parallel_for_each_reduce_over_dim_list_output_index for {Map,}ReduceOverDimListPlan ops #9197

Merged
merged 147 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
147 commits
Select commit Hold shift + click to select a range
d0b11e8
Update
swolchok Mar 4, 2025
9437be1
Update
swolchok Mar 4, 2025
643e10e
Update
swolchok Mar 4, 2025
6f2842b
Update
swolchok Mar 4, 2025
e47dfeb
Update
swolchok Mar 4, 2025
231ebc3
Update
swolchok Mar 5, 2025
296513c
Update
swolchok Mar 5, 2025
845a01e
Update
swolchok Mar 5, 2025
a92958a
Update
swolchok Mar 5, 2025
3fa99d6
Update
swolchok Mar 5, 2025
a6c69a6
Update
swolchok Mar 5, 2025
3bd6437
Update
swolchok Mar 5, 2025
675f01b
Update
swolchok Mar 5, 2025
5f3a768
Update
swolchok Mar 5, 2025
9fdebee
Update
swolchok Mar 5, 2025
70a7096
Update
swolchok Mar 5, 2025
337dc23
Update
swolchok Mar 5, 2025
f388177
Update
swolchok Mar 5, 2025
2949daf
Update
swolchok Mar 5, 2025
7347915
Update
swolchok Mar 5, 2025
1a8481d
Update
swolchok Mar 5, 2025
e48e816
Update
swolchok Mar 5, 2025
3351d50
Update
swolchok Mar 6, 2025
0102e25
Update
swolchok Mar 6, 2025
956f8a5
Update
swolchok Mar 6, 2025
9f7f0c1
Update
swolchok Mar 6, 2025
a1aeae7
Update
swolchok Mar 6, 2025
c658163
Update
swolchok Mar 6, 2025
7e0ccd4
Update
swolchok Mar 6, 2025
d9cd27c
Update
swolchok Mar 6, 2025
095d4f5
Update
swolchok Mar 6, 2025
e537a59
Update
swolchok Mar 6, 2025
c130224
Update
swolchok Mar 6, 2025
49c2971
Update
swolchok Mar 6, 2025
cd37460
Update
swolchok Mar 6, 2025
754a4f6
Update
swolchok Mar 6, 2025
11c5707
Update
swolchok Mar 6, 2025
7ca7627
Update
swolchok Mar 6, 2025
d428ca2
Update
swolchok Mar 6, 2025
62b6ef2
Update
swolchok Mar 6, 2025
b92ea35
Update
swolchok Mar 6, 2025
b478275
Update
swolchok Mar 7, 2025
0470870
Update
swolchok Mar 7, 2025
5a283c8
Update
swolchok Mar 7, 2025
a8a0e57
Update
swolchok Mar 7, 2025
df93cd4
Update
swolchok Mar 7, 2025
6350e07
Update
swolchok Mar 7, 2025
bd20770
Update
swolchok Mar 7, 2025
e7190a8
Update
swolchok Mar 7, 2025
4dd58a0
Update
swolchok Mar 7, 2025
1b6eb9f
Update
swolchok Mar 7, 2025
450e50b
Update
swolchok Mar 7, 2025
4459a7e
Update
swolchok Mar 7, 2025
fad4ed8
Update
swolchok Mar 7, 2025
085b624
Update
swolchok Mar 7, 2025
c7219a3
Update
swolchok Mar 7, 2025
e4af3bb
Update
swolchok Mar 8, 2025
9c9e31e
Update
swolchok Mar 8, 2025
34423ae
Update
swolchok Mar 8, 2025
379c10e
Update
swolchok Mar 8, 2025
fb5e06c
Update
swolchok Mar 8, 2025
4a0e893
Update
swolchok Mar 8, 2025
40a1bce
Update
swolchok Mar 8, 2025
37e4213
Update
swolchok Mar 8, 2025
ef1a0ce
Update
swolchok Mar 10, 2025
6844013
Update
swolchok Mar 10, 2025
e417a3b
Update
swolchok Mar 10, 2025
adaae97
Update
swolchok Mar 10, 2025
ea335ee
Update
swolchok Mar 10, 2025
2cc4910
Update
swolchok Mar 10, 2025
d7cdfa7
Update
swolchok Mar 10, 2025
98d6d01
Update
swolchok Mar 10, 2025
4a7ba26
Update
swolchok Mar 11, 2025
c0d1daa
Update
swolchok Mar 11, 2025
f408201
Update
swolchok Mar 11, 2025
e2fb689
Update
swolchok Mar 11, 2025
994c5f5
Update
swolchok Mar 11, 2025
b76240d
Update
swolchok Mar 11, 2025
18d5dde
Update
swolchok Mar 11, 2025
8cfdfa6
Update
swolchok Mar 11, 2025
4917358
Update
swolchok Mar 11, 2025
4a43b35
Update
swolchok Mar 11, 2025
3fe478d
Update
swolchok Mar 11, 2025
21d8aac
Update
swolchok Mar 11, 2025
2272c40
Update
swolchok Mar 11, 2025
f5bac6a
Update
swolchok Mar 11, 2025
c44fda6
Update
swolchok Mar 11, 2025
fa9ef9c
Update
swolchok Mar 11, 2025
73f37ee
Update
swolchok Mar 11, 2025
a8dd330
Update
swolchok Mar 11, 2025
0088cd2
Update
swolchok Mar 11, 2025
6b296df
Update
swolchok Mar 11, 2025
854c967
Update
swolchok Mar 11, 2025
03f00ee
Update
swolchok Mar 11, 2025
87085af
Update
swolchok Mar 11, 2025
2ee8846
Update
swolchok Mar 11, 2025
4779960
Update
swolchok Mar 11, 2025
e6be3fe
Update
swolchok Mar 11, 2025
930b2fd
Update
swolchok Mar 11, 2025
1ef9dd8
Update
swolchok Mar 11, 2025
c66f533
Update
swolchok Mar 11, 2025
e6d6ad6
Update
swolchok Mar 11, 2025
c0f0bec
Update
swolchok Mar 11, 2025
48e9452
Update
swolchok Mar 11, 2025
5dc8b27
Update
swolchok Mar 11, 2025
2dcb6db
Update
swolchok Mar 11, 2025
5781018
Update
swolchok Mar 11, 2025
caac9df
Update
swolchok Mar 11, 2025
66387af
Update
swolchok Mar 11, 2025
051f69c
Update
swolchok Mar 11, 2025
d8f4b13
Update
swolchok Mar 11, 2025
ecfabce
Update
swolchok Mar 11, 2025
a792450
Update
swolchok Mar 11, 2025
9fbe3a3
Update
swolchok Mar 11, 2025
f7cecde
Update
swolchok Mar 11, 2025
ba687e6
Update
swolchok Mar 11, 2025
cfbd318
Update
swolchok Mar 12, 2025
de50f9b
Update
swolchok Mar 12, 2025
d68ef30
Update
swolchok Mar 12, 2025
4ff9658
Update
swolchok Mar 12, 2025
2f094fc
Update
swolchok Mar 12, 2025
bec37a0
Update
swolchok Mar 12, 2025
723735f
Update
swolchok Mar 12, 2025
01f2790
Update
swolchok Mar 12, 2025
98e3147
Update
swolchok Mar 12, 2025
cb6aa4d
Update
swolchok Mar 12, 2025
eabae64
Update
swolchok Mar 12, 2025
046f5b6
Update
swolchok Mar 12, 2025
062908f
Update
swolchok Mar 12, 2025
18142d0
Update
swolchok Mar 12, 2025
f6c43ff
Update
swolchok Mar 12, 2025
cdcd351
Update
swolchok Mar 12, 2025
fc9caab
Update
swolchok Mar 12, 2025
d2d3f61
Update
swolchok Mar 12, 2025
824aebf
Update
swolchok Mar 12, 2025
939aabd
Update
swolchok Mar 12, 2025
2a4a905
Update
swolchok Mar 12, 2025
0c6246b
Update
swolchok Mar 13, 2025
605bfe7
Update
swolchok Mar 13, 2025
26da460
Update
swolchok Mar 13, 2025
9c307d5
Update
swolchok Mar 13, 2025
b82c9a8
Update
swolchok Mar 14, 2025
a2c5a30
Update
swolchok Mar 14, 2025
70a47ff
Update
swolchok Mar 14, 2025
f4e317c
Update
swolchok Mar 14, 2025
b734c85
Update
swolchok Mar 14, 2025
53b7998
Update
swolchok Mar 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions kernels/portable/cpu/op_amax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,17 @@ Tensor& amax_out(
ReduceOverDimListPlan plan(in, dim_list);
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
for (const auto out_ix : c10::irange(out.numel())) {
out_data[out_ix] = plan.execute<CTYPE>(
[](CTYPE v, CTYPE max_v) {
return std::isnan(v) || v > max_v ? v : max_v;
},
out_ix);
}
const bool success = parallel_for_each_reduce_over_dim_list_output_index(
in, dim_list, out, [&](const auto begin, const auto end) {
for (const auto out_ix : c10::irange(begin, end)) {
out_data[out_ix] = plan.execute<CTYPE>(
[](CTYPE v, CTYPE max_v) {
return std::isnan(v) || v > max_v ? v : max_v;
},
out_ix);
}
});
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
});

return out;
Expand Down
18 changes: 11 additions & 7 deletions kernels/portable/cpu/op_amin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,17 @@ Tensor& amin_out(
ReduceOverDimListPlan plan(in, dim_list);
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
for (const auto out_ix : c10::irange(out.numel())) {
out_data[out_ix] = plan.execute<CTYPE>(
[](CTYPE v, CTYPE min_v) {
return std::isnan(v) || v < min_v ? v : min_v;
},
out_ix);
}
const bool success = parallel_for_each_reduce_over_dim_list_output_index(
in, dim_list, out, [&](const auto begin, const auto end) {
for (const auto out_ix : c10::irange(begin, end)) {
out_data[out_ix] = plan.execute<CTYPE>(
[](CTYPE v, CTYPE min_v) {
return std::isnan(v) || v < min_v ? v : min_v;
},
out_ix);
}
});
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
});

return out;
Expand Down
25 changes: 15 additions & 10 deletions kernels/portable/cpu/op_any.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,21 @@ Tensor& any_dims_out(
static_cast<CTYPE_OUT>(static_cast<bool>(in_data[out_ix]));
}
} else {
for (const auto out_ix : c10::irange(out.numel())) {
bool any = false;
if (in_not_empty) {
any = plan->execute<CTYPE_IN, bool>(
[](CTYPE_IN v) { return static_cast<bool>(v); },
[](bool outv, bool acc) { return acc || outv; },
out_ix);
}
out_data[out_ix] = static_cast<CTYPE_OUT>(any);
}
const bool success =
parallel_for_each_reduce_over_dim_list_output_index(
in, dim_list, out, [&](const auto begin, const auto end) {
for (const auto out_ix : c10::irange(begin, end)) {
bool any = false;
if (in_not_empty) {
any = plan->execute<CTYPE_IN, bool>(
[](CTYPE_IN v) { return static_cast<bool>(v); },
[](bool outv, bool acc) { return acc || outv; },
out_ix);
}
out_data[out_ix] = static_cast<CTYPE_OUT>(any);
}
});
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
}
});
});
Expand Down
35 changes: 20 additions & 15 deletions kernels/portable/cpu/op_mean.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,27 @@ Tensor& mean_dim_out(
out);

MapReduceOverDimListPlan plan(in, dim_list);
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
ET_SWITCH_FLOATHBF16_TYPES(
out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
const size_t num = get_reduced_dim_product(in, dim_list);
for (const auto out_ix : c10::irange(out.numel())) {
CTYPE_OUT sum = 0;
if (in.numel() > 0) {
sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
out_ix);
// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "mean.out";
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
const size_t num = get_reduced_dim_product(in, dim_list);
const bool success = parallel_for_each_reduce_over_dim_list_output_index(
in, dim_list, out, [&](const auto begin, const auto end) {
for (const auto out_ix : c10::irange(begin, end)) {
CTYPE_OUT sum = 0;
if (in.numel() > 0) {
sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
out_ix);
}
out_data[out_ix] = sum / static_cast<float>(num);
}
out_data[out_ix] = sum / static_cast<float>(num);
}
});
});
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
});
});

return out;
Expand Down
36 changes: 20 additions & 16 deletions kernels/portable/cpu/op_sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,27 @@ Tensor& sum_dim_out(
if (in.numel() > 0) {
plan.emplace(in, dim_list);
}
ET_SWITCH_REALHBBF16_TYPES(
in.scalar_type(), ctx, "sum.IntList_out", CTYPE_IN, [&] {
ET_SWITCH_REALHBBF16_TYPES(
out.scalar_type(), ctx, "sum.IntList_out", CTYPE_OUT, [&] {
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
for (const auto out_ix : c10::irange(out.numel())) {
CTYPE_OUT sum = 0;
if (plan.has_value()) {
sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
out_ix);
}
out_data[out_ix] = sum;
// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "sum.IntList_out";
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
const bool success = parallel_for_each_reduce_over_dim_list_output_index(
in, dim_list, out, [&](const auto begin, const auto end) {
for (const auto out_ix : c10::irange(begin, end)) {
CTYPE_OUT sum = 0;
if (plan.has_value()) {
sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
out_ix);
}
out_data[out_ix] = sum;
}
});
});
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
});
});

return out;
}
Expand Down
41 changes: 23 additions & 18 deletions kernels/portable/cpu/op_var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace {

template <typename CTYPE_IN, typename CTYPE_OUT>
void compute_variance(
KernelRuntimeContext& ctx,
const Tensor& in,
Tensor& out,
optional<ArrayRef<int64_t>> dim_list,
Expand All @@ -33,22 +34,26 @@ void compute_variance(
}
} else {
MapReduceOverDimListPlan plan(in, dim_list);
for (const auto out_ix : c10::irange(out.numel())) {
CTYPE_OUT sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
out_ix);
CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
CTYPE_OUT sum2 = plan.execute<CTYPE_IN, CTYPE_OUT>(
[mean](CTYPE_IN v) {
return (
(static_cast<CTYPE_OUT>(v) - mean) *
(static_cast<CTYPE_OUT>(v) - mean));
},
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
out_ix);
out_data[out_ix] = sum2 / denominator;
}
const bool success = parallel_for_each_reduce_over_dim_list_output_index(
in, dim_list, out, [&](const auto begin, const auto end) {
for (const auto out_ix : c10::irange(begin, end)) {
CTYPE_OUT sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
out_ix);
CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
CTYPE_OUT sum2 = plan.execute<CTYPE_IN, CTYPE_OUT>(
[mean](CTYPE_IN v) {
return (
(static_cast<CTYPE_OUT>(v) - mean) *
(static_cast<CTYPE_OUT>(v) - mean));
},
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
out_ix);
out_data[out_ix] = sum2 / denominator;
}
});
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
}
}

Expand Down Expand Up @@ -90,7 +95,7 @@ Tensor& var_out(

ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, name, CTYPE_OUT, [&] {
compute_variance<CTYPE_IN, CTYPE_OUT>(in, out, dim_list, num, denom);
compute_variance<CTYPE_IN, CTYPE_OUT>(ctx, in, out, dim_list, num, denom);
});
});

Expand Down Expand Up @@ -135,7 +140,7 @@ Tensor& var_correction_out(

ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, name, CTYPE_OUT, [&] {
compute_variance<CTYPE_IN, CTYPE_OUT>(in, out, dim_list, num, denom);
compute_variance<CTYPE_IN, CTYPE_OUT>(ctx, in, out, dim_list, num, denom);
});
});

Expand Down
8 changes: 8 additions & 0 deletions kernels/portable/cpu/util/reduce_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -823,11 +823,15 @@ template <typename Func>
executorch::aten::optional<int64_t> dim,
const Tensor& out,
const Func& func) {
#ifdef ET_USE_THREADPOOL
const ssize_t reduction_size = get_reduced_dim_product(in, dim);
const auto grain_size = std::max(
static_cast<ssize_t>(1),
static_cast<ssize_t>(executorch::extension::internal::GRAIN_SIZE) /
reduction_size);
#else // ET_USE_THREADPOOL
const auto grain_size = 1;
#endif // ET_USE_THREADPOOL
return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
}

Expand All @@ -842,11 +846,15 @@ template <typename Func>
optional<ArrayRef<int64_t>> dim_list,
const Tensor& out,
const Func& func) {
#ifdef ET_USE_THREADPOOL
const ssize_t reduction_size = get_reduced_dim_product(in, dim_list);
const auto grain_size = std::max(
static_cast<ssize_t>(1),
static_cast<ssize_t>(executorch::extension::internal::GRAIN_SIZE) /
reduction_size);
#else // ET_USE_THREADPOOL
const auto grain_size = 1;
#endif // ET_USE_THREADPOOL
return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
}

Expand Down
Loading