Skip to content

Commit d12b767

Browse files
Implement [u]int8 and [u]int16 matrix transpose for 128 bit registers
Related to #1054
1 parent f78c251 commit d12b767

File tree

2 files changed

+145
-1
lines changed

2 files changed

+145
-1
lines changed

Diff for: include/xsimd/arch/generic/xsimd_generic_memory.hpp

+136-1
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ namespace xsimd
627627
hi.store_aligned(buffer + real_batch::size);
628628
}
629629

630-
// store_compelx_unaligned
630+
// store_complex_unaligned
631631
template <class A, class T_out, class T_in>
632632
XSIMD_INLINE void store_complex_unaligned(std::complex<T_out>* dst, batch<std::complex<T_in>, A> const& src, requires_arch<generic>) noexcept
633633
{
@@ -665,6 +665,141 @@ namespace xsimd
665665
}
666666
}
667667

668+
// transpose
669+
template <class A, class = typename std::enable_if<batch<int16_t, A>::size == 8, void>::type>
670+
XSIMD_INLINE void transpose(batch<int16_t, A>* matrix_begin, batch<int16_t, A>* matrix_end, requires_arch<generic>) noexcept
671+
{
672+
assert((matrix_end - matrix_begin == batch<int16_t, A>::size) && "correctly sized matrix");
673+
(void)matrix_end;
674+
auto l0 = zip_lo(matrix_begin[0], matrix_begin[1]);
675+
auto l1 = zip_lo(matrix_begin[2], matrix_begin[3]);
676+
auto l2 = zip_lo(matrix_begin[4], matrix_begin[5]);
677+
auto l3 = zip_lo(matrix_begin[6], matrix_begin[7]);
678+
679+
auto l4 = zip_lo(bit_cast<batch<int32_t, A>>(l0), bit_cast<batch<int32_t, A>>(l1));
680+
auto l5 = zip_lo(bit_cast<batch<int32_t, A>>(l2), bit_cast<batch<int32_t, A>>(l3));
681+
682+
auto l6 = zip_hi(bit_cast<batch<int32_t, A>>(l0), bit_cast<batch<int32_t, A>>(l1));
683+
auto l7 = zip_hi(bit_cast<batch<int32_t, A>>(l2), bit_cast<batch<int32_t, A>>(l3));
684+
685+
auto h0 = zip_hi(matrix_begin[0], matrix_begin[1]);
686+
auto h1 = zip_hi(matrix_begin[2], matrix_begin[3]);
687+
auto h2 = zip_hi(matrix_begin[4], matrix_begin[5]);
688+
auto h3 = zip_hi(matrix_begin[6], matrix_begin[7]);
689+
690+
auto h4 = zip_lo(bit_cast<batch<int32_t, A>>(h0), bit_cast<batch<int32_t, A>>(h1));
691+
auto h5 = zip_lo(bit_cast<batch<int32_t, A>>(h2), bit_cast<batch<int32_t, A>>(h3));
692+
693+
auto h6 = zip_hi(bit_cast<batch<int32_t, A>>(h0), bit_cast<batch<int32_t, A>>(h1));
694+
auto h7 = zip_hi(bit_cast<batch<int32_t, A>>(h2), bit_cast<batch<int32_t, A>>(h3));
695+
696+
matrix_begin[0] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(l4), bit_cast<batch<int64_t, A>>(l5)));
697+
matrix_begin[1] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(l4), bit_cast<batch<int64_t, A>>(l5)));
698+
matrix_begin[2] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(l6), bit_cast<batch<int64_t, A>>(l7)));
699+
matrix_begin[3] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(l6), bit_cast<batch<int64_t, A>>(l7)));
700+
701+
matrix_begin[4] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(h4), bit_cast<batch<int64_t, A>>(h5)));
702+
matrix_begin[5] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(h4), bit_cast<batch<int64_t, A>>(h5)));
703+
matrix_begin[6] = bit_cast<batch<int16_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(h6), bit_cast<batch<int64_t, A>>(h7)));
704+
matrix_begin[7] = bit_cast<batch<int16_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(h6), bit_cast<batch<int64_t, A>>(h7)));
705+
}
706+
707+
template <class A>
708+
XSIMD_INLINE void transpose(batch<uint16_t, A>* matrix_begin, batch<uint16_t, A>* matrix_end, requires_arch<generic>) noexcept
709+
{
710+
transpose(reinterpret_cast<batch<int16_t, A>*>(matrix_begin), reinterpret_cast<batch<int16_t, A>*>(matrix_end), A {});
711+
}
712+
713+
template <class A, class = typename std::enable_if<batch<int8_t, A>::size == 16, void>::type>
714+
XSIMD_INLINE void transpose(batch<int8_t, A>* matrix_begin, batch<int8_t, A>* matrix_end, requires_arch<generic>) noexcept
715+
{
716+
assert((matrix_end - matrix_begin == batch<int8_t, A>::size) && "correctly sized matrix");
717+
(void)matrix_end;
718+
auto l0 = zip_lo(matrix_begin[0], matrix_begin[1]);
719+
auto l1 = zip_lo(matrix_begin[2], matrix_begin[3]);
720+
auto l2 = zip_lo(matrix_begin[4], matrix_begin[5]);
721+
auto l3 = zip_lo(matrix_begin[6], matrix_begin[7]);
722+
auto l4 = zip_lo(matrix_begin[8], matrix_begin[9]);
723+
auto l5 = zip_lo(matrix_begin[10], matrix_begin[11]);
724+
auto l6 = zip_lo(matrix_begin[12], matrix_begin[13]);
725+
auto l7 = zip_lo(matrix_begin[14], matrix_begin[15]);
726+
727+
auto h0 = zip_hi(matrix_begin[0], matrix_begin[1]);
728+
auto h1 = zip_hi(matrix_begin[2], matrix_begin[3]);
729+
auto h2 = zip_hi(matrix_begin[4], matrix_begin[5]);
730+
auto h3 = zip_hi(matrix_begin[6], matrix_begin[7]);
731+
auto h4 = zip_hi(matrix_begin[8], matrix_begin[9]);
732+
auto h5 = zip_hi(matrix_begin[10], matrix_begin[11]);
733+
auto h6 = zip_hi(matrix_begin[12], matrix_begin[13]);
734+
auto h7 = zip_hi(matrix_begin[14], matrix_begin[15]);
735+
736+
auto L0 = zip_lo(bit_cast<batch<int16_t, A>>(l0), bit_cast<batch<int16_t, A>>(l1));
737+
auto L1 = zip_lo(bit_cast<batch<int16_t, A>>(l2), bit_cast<batch<int16_t, A>>(l3));
738+
auto L2 = zip_lo(bit_cast<batch<int16_t, A>>(l4), bit_cast<batch<int16_t, A>>(l5));
739+
auto L3 = zip_lo(bit_cast<batch<int16_t, A>>(l6), bit_cast<batch<int16_t, A>>(l7));
740+
741+
auto m0 = zip_lo(bit_cast<batch<int32_t, A>>(L0), bit_cast<batch<int32_t, A>>(L1));
742+
auto m1 = zip_lo(bit_cast<batch<int32_t, A>>(L2), bit_cast<batch<int32_t, A>>(L3));
743+
auto m2 = zip_hi(bit_cast<batch<int32_t, A>>(L0), bit_cast<batch<int32_t, A>>(L1));
744+
auto m3 = zip_hi(bit_cast<batch<int32_t, A>>(L2), bit_cast<batch<int32_t, A>>(L3));
745+
746+
matrix_begin[0] = bit_cast<batch<int8_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(m0), bit_cast<batch<int64_t, A>>(m1)));
747+
matrix_begin[1] = bit_cast<batch<int8_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(m0), bit_cast<batch<int64_t, A>>(m1)));
748+
matrix_begin[2] = bit_cast<batch<int8_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(m2), bit_cast<batch<int64_t, A>>(m3)));
749+
matrix_begin[3] = bit_cast<batch<int8_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(m2), bit_cast<batch<int64_t, A>>(m3)));
750+
751+
auto L4 = zip_hi(bit_cast<batch<int16_t, A>>(l0), bit_cast<batch<int16_t, A>>(l1));
752+
auto L5 = zip_hi(bit_cast<batch<int16_t, A>>(l2), bit_cast<batch<int16_t, A>>(l3));
753+
auto L6 = zip_hi(bit_cast<batch<int16_t, A>>(l4), bit_cast<batch<int16_t, A>>(l5));
754+
auto L7 = zip_hi(bit_cast<batch<int16_t, A>>(l6), bit_cast<batch<int16_t, A>>(l7));
755+
756+
auto m4 = zip_lo(bit_cast<batch<int32_t, A>>(L4), bit_cast<batch<int32_t, A>>(L5));
757+
auto m5 = zip_lo(bit_cast<batch<int32_t, A>>(L6), bit_cast<batch<int32_t, A>>(L7));
758+
auto m6 = zip_hi(bit_cast<batch<int32_t, A>>(L4), bit_cast<batch<int32_t, A>>(L5));
759+
auto m7 = zip_hi(bit_cast<batch<int32_t, A>>(L6), bit_cast<batch<int32_t, A>>(L7));
760+
761+
matrix_begin[4] = bit_cast<batch<int8_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(m4), bit_cast<batch<int64_t, A>>(m5)));
762+
matrix_begin[5] = bit_cast<batch<int8_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(m4), bit_cast<batch<int64_t, A>>(m5)));
763+
matrix_begin[6] = bit_cast<batch<int8_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(m6), bit_cast<batch<int64_t, A>>(m7)));
764+
matrix_begin[7] = bit_cast<batch<int8_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(m6), bit_cast<batch<int64_t, A>>(m7)));
765+
766+
auto H0 = zip_lo(bit_cast<batch<int16_t, A>>(h0), bit_cast<batch<int16_t, A>>(h1));
767+
auto H1 = zip_lo(bit_cast<batch<int16_t, A>>(h2), bit_cast<batch<int16_t, A>>(h3));
768+
auto H2 = zip_lo(bit_cast<batch<int16_t, A>>(h4), bit_cast<batch<int16_t, A>>(h5));
769+
auto H3 = zip_lo(bit_cast<batch<int16_t, A>>(h6), bit_cast<batch<int16_t, A>>(h7));
770+
771+
auto M0 = zip_lo(bit_cast<batch<int32_t, A>>(H0), bit_cast<batch<int32_t, A>>(H1));
772+
auto M1 = zip_lo(bit_cast<batch<int32_t, A>>(H2), bit_cast<batch<int32_t, A>>(H3));
773+
auto M2 = zip_hi(bit_cast<batch<int32_t, A>>(H0), bit_cast<batch<int32_t, A>>(H1));
774+
auto M3 = zip_hi(bit_cast<batch<int32_t, A>>(H2), bit_cast<batch<int32_t, A>>(H3));
775+
776+
matrix_begin[8] = bit_cast<batch<int8_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(M0), bit_cast<batch<int64_t, A>>(M1)));
777+
matrix_begin[9] = bit_cast<batch<int8_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(M0), bit_cast<batch<int64_t, A>>(M1)));
778+
matrix_begin[10] = bit_cast<batch<int8_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(M2), bit_cast<batch<int64_t, A>>(M3)));
779+
matrix_begin[11] = bit_cast<batch<int8_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(M2), bit_cast<batch<int64_t, A>>(M3)));
780+
781+
auto H4 = zip_hi(bit_cast<batch<int16_t, A>>(h0), bit_cast<batch<int16_t, A>>(h1));
782+
auto H5 = zip_hi(bit_cast<batch<int16_t, A>>(h2), bit_cast<batch<int16_t, A>>(h3));
783+
auto H6 = zip_hi(bit_cast<batch<int16_t, A>>(h4), bit_cast<batch<int16_t, A>>(h5));
784+
auto H7 = zip_hi(bit_cast<batch<int16_t, A>>(h6), bit_cast<batch<int16_t, A>>(h7));
785+
786+
auto M4 = zip_lo(bit_cast<batch<int32_t, A>>(H4), bit_cast<batch<int32_t, A>>(H5));
787+
auto M5 = zip_lo(bit_cast<batch<int32_t, A>>(H6), bit_cast<batch<int32_t, A>>(H7));
788+
auto M6 = zip_hi(bit_cast<batch<int32_t, A>>(H4), bit_cast<batch<int32_t, A>>(H5));
789+
auto M7 = zip_hi(bit_cast<batch<int32_t, A>>(H6), bit_cast<batch<int32_t, A>>(H7));
790+
791+
matrix_begin[12] = bit_cast<batch<int8_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(M4), bit_cast<batch<int64_t, A>>(M5)));
792+
matrix_begin[13] = bit_cast<batch<int8_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(M4), bit_cast<batch<int64_t, A>>(M5)));
793+
matrix_begin[14] = bit_cast<batch<int8_t, A>>(zip_lo(bit_cast<batch<int64_t, A>>(M6), bit_cast<batch<int64_t, A>>(M7)));
794+
matrix_begin[15] = bit_cast<batch<int8_t, A>>(zip_hi(bit_cast<batch<int64_t, A>>(M6), bit_cast<batch<int64_t, A>>(M7)));
795+
}
796+
797+
template <class A>
798+
XSIMD_INLINE void transpose(batch<uint8_t, A>* matrix_begin, batch<uint8_t, A>* matrix_end, requires_arch<generic>) noexcept
799+
{
800+
transpose(reinterpret_cast<batch<int8_t, A>*>(matrix_begin), reinterpret_cast<batch<int8_t, A>*>(matrix_end), A {});
801+
}
802+
668803
}
669804

670805
}

Diff for: test/test_shuffle.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -723,4 +723,13 @@ TEST_CASE_TEMPLATE("[shuffle]", B, BATCH_FLOAT_TYPES, xsimd::batch<uint32_t>, xs
723723
}
724724
}
725725

726+
TEST_CASE_TEMPLATE("[small integer transpose]", B, xsimd::batch<uint16_t>, xsimd::batch<int16_t>, xsimd::batch<uint8_t>, xsimd::batch<int8_t>)
727+
{
728+
shuffle_test<B> Test;
729+
SUBCASE("transpose")
730+
{
731+
Test.transpose();
732+
}
733+
}
734+
726735
#endif

0 commit comments

Comments
 (0)