Skip to content

Commit 2b4f903

Browse files
Implement batch_bool::mask() for riscv
As a followup to #1236
1 parent af74418 commit 2b4f903

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

include/xsimd/arch/xsimd_rvv.hpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <type_traits>
1515
#include <utility>
1616

17+
#include "../types/xsimd_batch_constant.hpp"
1718
#include "../types/xsimd_rvv_register.hpp"
1819
#include "xsimd_constants.hpp"
1920

@@ -1504,6 +1505,59 @@ namespace xsimd
15041505
const auto mask = abs(arg) < constants::maxflint<batch<T, A>>();
15051506
return select(mask, to_float(detail::rvvfcvt_default(arg)), arg, rvv {});
15061507
}
1508+
1509+
// mask
1510+
template <class A, class T>
1511+
XSIMD_INLINE uint64_t mask(batch_bool<T, A> const& self, requires_arch<common>) noexcept;
1512+
1513+
template <class A, class T>
1514+
XSIMD_INLINE uint64_t mask(batch_bool<T, A> const& self, requires_arch<rvv>) noexcept
1515+
{
1516+
XSIMD_IF_CONSTEXPR((8 * sizeof(T)) >= batch_bool<T, A>::size)
1517+
{
1518+
// (A) Easy case: the number of slots fits in T.
1519+
const auto zero = detail::broadcast<as_unsigned_integer_t<T>, types::detail::rvv_width_m1>(T(0));
1520+
auto ones = detail::broadcast<as_unsigned_integer_t<T>, A::width>(1);
1521+
auto iota = detail::rvvid(as_unsigned_integer_t<T> {});
1522+
auto upowers = detail::rvvsll(ones, iota);
1523+
auto r = __riscv_vredor(self.data.as_mask(), upowers, (typename decltype(zero)::register_type)zero, batch_bool<T, A>::size);
1524+
return detail::reduce_scalar<A, as_unsigned_integer_t<T>>(r);
1525+
}
1526+
else XSIMD_IF_CONSTEXPR((2 * 8 * sizeof(T)) == batch_bool<T, A>::size)
1527+
{
1528+
// (B) We need two rounds, one for the low part, one for the high part.
1529+
1530+
struct LowerHalf
1531+
{
1532+
static constexpr bool get(unsigned i, unsigned n) { return i < n / 2; }
1533+
};
1534+
struct UpperHalf
1535+
{
1536+
static constexpr bool get(unsigned i, unsigned n) { return i >= n / 2; }
1537+
};
1538+
1539+
// The low part is similar to the approach in (A).
1540+
const auto zero = detail::broadcast<as_unsigned_integer_t<T>, types::detail::rvv_width_m1>(T(0));
1541+
auto ones = detail::broadcast<as_unsigned_integer_t<T>, A::width>(1);
1542+
auto iota = detail::rvvid(as_unsigned_integer_t<T> {});
1543+
auto upowers = detail::rvvsll(ones, iota);
1544+
auto low_mask = self & make_batch_bool_constant<T, LowerHalf, A>();
1545+
auto r_low = __riscv_vredor(low_mask.data.as_mask(), upowers, (typename decltype(zero)::register_type)zero, batch_bool<T, A>::size);
1546+
1547+
// The high part requires to slide the upower filter to match the high mask.
1548+
upowers = detail::rvvslideup(upowers, upowers, 8 * sizeof(T));
1549+
auto high_mask = self & make_batch_bool_constant<T, UpperHalf, A>();
1550+
auto r_high = __riscv_vredor(high_mask.data.as_mask(), upowers, (typename decltype(zero)::register_type)zero, batch_bool<T, A>::size);
1551+
1552+
// Concatenate the two parts.
1553+
return (uint64_t)detail::reduce_scalar<A, as_unsigned_integer_t<T>>(r_low) | ((uint64_t)detail::reduce_scalar<A, as_unsigned_integer_t<T>>(r_high) << (8 * sizeof(T)));
1554+
}
1555+
else
1556+
{
1557+
// (C) we could generalize (B) but we already cover a lot of case now.
1558+
return mask(self, common {});
1559+
}
1560+
}
15071561
} // namespace kernel
15081562
} // namespace xsimd
15091563

include/xsimd/types/xsimd_rvv_register.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ namespace xsimd
391391
{
392392
}
393393
// Implicit conversion to `type`: the stored `value` is passed through
// `bool_info::bitcast` and the result is returned.
operator type() const noexcept
{
    return bool_info::bitcast(value);
}
394+
// Named equivalent of the conversion to `type`: returns this object
// converted to `type`, for call sites where an explicit call reads better
// than a cast.
type as_mask() const noexcept
{
    return static_cast<type>(*this);
}
394395
};
395396

396397
template <class T, size_t Width = XSIMD_RVV_BITS>

0 commit comments

Comments
 (0)