[BUG] Unaligned access in test/unit/gemm/threadblock/batched_gemv.cu #2003

Open
Artem-B opened this issue Dec 19, 2024 · 1 comment
Labels
? - Needs Triage bug Something isn't working


Artem-B commented Dec 19, 2024

Describe the bug
When the tests are built without optimizations, `test/unit/gemm/threadblock/batched_gemv.cu` crashes with a misaligned-address exception.

It's not clear whether the test fails to set the correct alignment on the GEMM parameters, or CUTLASS itself assumes a specific alignment where it should not. Forcing alignment on the input avoids the issue, but I'm not sure whether that is a fix or just a workaround.

```
[ RUN      ] SM50_batched_gemv_threadblock.16x1x17x64_rcr_fp32_fp32_1N_4K
third_party/gpus/cutlass/test/unit/gemm/threadblock/batched_gemv.cu.cc:216: Failure
Expected equality of these values:
  result
    Which is: misaligned address
  cudaSuccess
    Which is: no error
 kernel error: misaligned address

libc++abi: terminating due to uncaught exception of type cutlass::cuda_exception: std::exception
[1]    46252 IOT instruction  blaze-bin/third_party/gpus/cutlass/test/gemm/threadblock/batched_gemv
```

Steps/Code to reproduce bug

Build CUTLASS with clang without optimizations, run the CUTLASS tests, and observe that some of them fail.

Expected behavior

The root cause is that CUTLASS implicitly relies on everything being inlined and on some of the intermediate operations being optimized away.
When they are not, it becomes apparent that CUTLASS code passes opaque pointers around in many places and assumes they are aligned to a certain boundary. The callers do not always guarantee that assumption; in this case the caller is the batched_gemv.cu test.

For example, `accum` has no alignment specified here:

```cpp
typename Gemv::FragmentC accum;
accum.clear();
// Compute threadblock-scoped matrix multiply-add
gemv(problem_size, accum, iterator_A, iterator_B, accum);
```

...and eventually a pointer to it is used to load an 8-byte word, which triggers a misaligned-address exception here:

```cpp
char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
if (address_iterator_.valid()) {
  *access_ptr = frag_ptr[idx];
}
```

Stack trace:

```
Thread 1 "batched_gemv" received signal CUDA_EXCEPTION_6, Warp Misaligned Address.
[Switching focus to CUDA kernel 0, grid 9, block (0,0,0), thread (0,3,0), device 0, sm 0, warp 3, lane 0]
0x00007ffd19ec8690 in cutlass::transform::threadblock::PredicatedTileIterator<cutlass::PitchLinearShape<64, 1>, float, cutlass::layout::PitchLinear, 1, cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<cutlass::PitchLinearShape<64, 1>, 32, 2>, 2, false, cutlass::layout::NoPermute>::store_with_byte_offset (this=0x7ffff0fffc08, frag=..., byte_offset=0) at /proc/self/cwd/./third_party/gpus/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h:384
384                 *access_ptr = frag_ptr[idx];
(cuda-gdb) bt
#0  0x00007ffd19ec8690 in cutlass::transform::threadblock::PredicatedTileIterator<cutlass::PitchLinearShape<64, 1>, float, cutlass::layout::PitchLinear, 1, cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<cutlass::PitchLinearShape<64, 1>, 32, 2>, 2, false, cutlass::layout::NoPermute>::store_with_byte_offset (this=0x7ffff0fffc08, frag=..., byte_offset=0) at /proc/self/cwd/./third_party/gpus/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h:384
#1  cutlass::transform::threadblock::PredicatedTileIterator<cutlass::PitchLinearShape<64, 1>, float, cutlass::layout::PitchLinear, 1, cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<cutlass::PitchLinearShape<64, 1>, 32, 2>, 2, false, cutlass::layout::NoPermute>::store_with_pointer_offset (this=0x7ffff0fffc08, frag=..., pointer_offset=0) at /proc/self/cwd/./third_party/gpus/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h:362
#2  cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, 64>, float, cutlass::layout::RowMajor, 0, cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<cutlass::PitchLinearShape<64, 1>, 32, 2>, 2, false, cutlass::layout::NoPermute>::store_with_pointer_offset (this=0x7ffff0fffc08, frag=..., pointer_offset=0) at /proc/self/cwd/./third_party/gpus/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h:816
#3  cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, 64>, float, cutlass::layout::RowMajor, 0, cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<cutlass::PitchLinearShape<64, 1>, 32, 2>, 2, false, cutlass::layout::NoPermute>::store<<<(1,1,1),(32,4,1)>>> (this=0x7ffff0fffc08, frag=...) at /proc/self/cwd/./third_party/gpus/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h:828
#4  test::gemm::threadblock::batched_gemv_threadblock_test_kernel<cutlass::gemm::threadblock::Gemv<cutlass::gemm::threadblock::DefaultGemvCore<cutlass::gemm::GemmShape<1, 64, 2>, cutlass::gemm::GemmShape<1, 2, 2>, float, cutlass::layout::RowMajor, float, cutlass::layout::ColumnMajor, float, cutlass::layout::RowMajor> >, long, cutlass::TensorRef<float, cutlass::layout::RowMajor>, cutlass::TensorRef<float, cutlass::layout::ColumnMajor>, cutlass::TensorRef<float, cutlass::layout::RowMajor> >
   <<<(1,1,1),(32,4,1)>>> (warning: Could not find DWO CU blaze-out/k8-dbg/bin/third_party/gpus/cutlass/test/_objs/gemm/threadblock/batched_gemv_lib.host.0/batched_gemv.cu.pic.dwo(0xa1063e0a09a65a6b) referenced by CU at offset 0x4ec [in module /google/obj/workspace/59020db8998c499a49126ed0daf698aa034958bc800d56c31bc15c93b4d9bbce/ecad6e51-6ea9-4661-8eb4-75ae4e6417cc/blaze-out/k8-dbg/bin/third_party/gpus/cutlass/test/gemm/threadblock/batched_gemv]
problem_size=<incomplete type>, stride_a=64, stride_b=4096, stride_c=64, ref_A=..., ref_B=..., ref_C=...) at third_party/gpus/cutlass/test/unit/gemm/threadblock/batched_gemv.cu.cc:99
```

Environment details:

  • cutlass 3.4.1 bbe579a
  • clang @ HEAD as the compiler
  • custom build with bazel

Possible fix or workaround:

```diff
@@ -87,7 +87,7 @@ template <typename Gemv, typename LongIn

   Gemv gemv;

-  typename Gemv::FragmentC accum;
+  typename Gemv::FragmentC accum alignas(16);
   accum.clear();

   // Compute threadblock-scoped matrix multiply-add
```
Artem-B added the ? - Needs Triage and bug labels on Dec 19, 2024

Artem-B commented Dec 20, 2024

Another source of unaligned inputs is in the CUTLASS headers themselves:

```diff
diff --git a/cutlass/include/cutlass/gemm/threadblock/gemv.h b/cutlass/include/cutlass/gemm/threadblock/gemv.h
--- a/cutlass/include/cutlass/gemm/threadblock/gemv.h
+++ b/cutlass/include/cutlass/gemm/threadblock/gemv.h
@@ -98,8 +98,8 @@ public:
     // Prologue
     //

-    FragmentA frag_A;
-    FragmentB frag_B;
+    FragmentA frag_A alignas(16);
+    FragmentB frag_B alignas(16);
     frag_A.clear();
     frag_B.clear();
```

I suspect there may be more.
